forked from MirrorHub/synapse
		
	Merge branch 'develop' of https://github.com/matrix-org/synapse into release-v0.27.0
This commit is contained in:
		
				commit
				
					
						5e785d4d5b
					
				
			
		
					 35 changed files with 729 additions and 187 deletions
				
			
		UPGRADE.rst
docs
scripts
synapse
api
app
federation/transport
groups
handlers
http
replication/tcp
rest/client
state.pystorage
__init__.pyclient_ips.pyevents.pygroup_server.pyprepare_database.pyroom.pyroommember.py
schema/delta
search.pyuser_directory.pyutil
tests/util/caches
							
								
								
									
										12
									
								
								UPGRADE.rst
									
										
									
									
									
								
							
							
						
						
									
										12
									
								
								UPGRADE.rst
									
										
									
									
									
								
							|  | @ -48,6 +48,18 @@ returned by the Client-Server API: | |||
|     # configured on port 443. | ||||
|     curl -kv https://<host.name>/_matrix/client/versions 2>&1 | grep "Server:" | ||||
| 
 | ||||
| Upgrading to $NEXT_VERSION | ||||
| ==================== | ||||
| 
 | ||||
| This release expands the anonymous usage stats sent if the opt-in | ||||
| ``report_stats`` configuration is set to ``true``. We now capture RSS memory  | ||||
| and cpu use at a very coarse level. This requires administrators to install | ||||
| the optional ``psutil`` python module. | ||||
| 
 | ||||
| We would appreciate it if you could assist by ensuring this module is available | ||||
| and ``report_stats`` is enabled. This will let us see if performance changes to | ||||
| synapse are having an impact to the general community. | ||||
| 
 | ||||
| Upgrading to v0.15.0 | ||||
| ==================== | ||||
| 
 | ||||
|  |  | |||
|  | @ -55,7 +55,12 @@ synapse process.) | |||
| 
 | ||||
| You then create a set of configs for the various worker processes.  These | ||||
| should be worker configuration files, and should be stored in a dedicated | ||||
| subdirectory, to allow synctl to manipulate them. | ||||
| subdirectory, to allow synctl to manipulate them. An additional configuration | ||||
| for the master synapse process will need to be created because the process will | ||||
| not be started automatically. That configuration should look like this:: | ||||
| 
 | ||||
|     worker_app: synapse.app.homeserver | ||||
|     daemonize: true | ||||
| 
 | ||||
| Each worker configuration file inherits the configuration of the main homeserver | ||||
| configuration file.  You can then override configuration specific to that worker, | ||||
|  | @ -230,9 +235,11 @@ file. For example:: | |||
| ``synapse.app.event_creator`` | ||||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| 
 | ||||
| Handles non-state event creation. It can handle REST endpoints matching:: | ||||
| Handles some event creation. It can handle REST endpoints matching:: | ||||
| 
 | ||||
|     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send | ||||
|     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$ | ||||
|     ^/_matrix/client/(api/v1|r0|unstable)/join/ | ||||
| 
 | ||||
| It will create events locally and then send them on to the main synapse | ||||
| instance to be persisted and handled. | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| #!/usr/bin/env python | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015, 2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -250,6 +251,12 @@ class Porter(object): | |||
|     @defer.inlineCallbacks | ||||
|     def handle_table(self, table, postgres_size, table_size, forward_chunk, | ||||
|                      backward_chunk): | ||||
|         logger.info( | ||||
|             "Table %s: %i/%i (rows %i-%i) already ported", | ||||
|             table, postgres_size, table_size, | ||||
|             backward_chunk+1, forward_chunk-1, | ||||
|         ) | ||||
| 
 | ||||
|         if not table_size: | ||||
|             return | ||||
| 
 | ||||
|  | @ -467,31 +474,10 @@ class Porter(object): | |||
|             self.progress.set_state("Preparing PostgreSQL") | ||||
|             self.setup_db(postgres_config, postgres_engine) | ||||
| 
 | ||||
|             # Step 2. Get tables. | ||||
|             self.progress.set_state("Fetching tables") | ||||
|             sqlite_tables = yield self.sqlite_store._simple_select_onecol( | ||||
|                 table="sqlite_master", | ||||
|                 keyvalues={ | ||||
|                     "type": "table", | ||||
|                 }, | ||||
|                 retcol="name", | ||||
|             ) | ||||
| 
 | ||||
|             postgres_tables = yield self.postgres_store._simple_select_onecol( | ||||
|                 table="information_schema.tables", | ||||
|                 keyvalues={}, | ||||
|                 retcol="distinct table_name", | ||||
|             ) | ||||
| 
 | ||||
|             tables = set(sqlite_tables) & set(postgres_tables) | ||||
| 
 | ||||
|             self.progress.set_state("Creating tables") | ||||
| 
 | ||||
|             logger.info("Found %d tables", len(tables)) | ||||
| 
 | ||||
|             self.progress.set_state("Creating port tables") | ||||
|             def create_port_table(txn): | ||||
|                 txn.execute( | ||||
|                     "CREATE TABLE port_from_sqlite3 (" | ||||
|                     "CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" | ||||
|                     " table_name varchar(100) NOT NULL UNIQUE," | ||||
|                     " forward_rowid bigint NOT NULL," | ||||
|                     " backward_rowid bigint NOT NULL" | ||||
|  | @ -517,18 +503,33 @@ class Porter(object): | |||
|                     "alter_table", alter_table | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 logger.info("Failed to create port table: %s", e) | ||||
|                 pass | ||||
| 
 | ||||
|             try: | ||||
|                 yield self.postgres_store.runInteraction( | ||||
|                     "create_port_table", create_port_table | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 logger.info("Failed to create port table: %s", e) | ||||
|             yield self.postgres_store.runInteraction( | ||||
|                 "create_port_table", create_port_table | ||||
|             ) | ||||
| 
 | ||||
|             self.progress.set_state("Setting up") | ||||
|             # Step 2. Get tables. | ||||
|             self.progress.set_state("Fetching tables") | ||||
|             sqlite_tables = yield self.sqlite_store._simple_select_onecol( | ||||
|                 table="sqlite_master", | ||||
|                 keyvalues={ | ||||
|                     "type": "table", | ||||
|                 }, | ||||
|                 retcol="name", | ||||
|             ) | ||||
| 
 | ||||
|             # Set up tables. | ||||
|             postgres_tables = yield self.postgres_store._simple_select_onecol( | ||||
|                 table="information_schema.tables", | ||||
|                 keyvalues={}, | ||||
|                 retcol="distinct table_name", | ||||
|             ) | ||||
| 
 | ||||
|             tables = set(sqlite_tables) & set(postgres_tables) | ||||
|             logger.info("Found %d tables", len(tables)) | ||||
| 
 | ||||
|             # Step 3. Figure out what still needs copying | ||||
|             self.progress.set_state("Checking on port progress") | ||||
|             setup_res = yield defer.gatherResults( | ||||
|                 [ | ||||
|                     self.setup_table(table) | ||||
|  | @ -539,7 +540,8 @@ class Porter(object): | |||
|                 consumeErrors=True, | ||||
|             ) | ||||
| 
 | ||||
|             # Process tables. | ||||
|             # Step 4. Do the copying. | ||||
|             self.progress.set_state("Copying to postgres") | ||||
|             yield defer.gatherResults( | ||||
|                 [ | ||||
|                     self.handle_table(*res) | ||||
|  | @ -548,6 +550,9 @@ class Porter(object): | |||
|                 consumeErrors=True, | ||||
|             ) | ||||
| 
 | ||||
|             # Step 5. Do final post-processing | ||||
|             yield self._setup_state_group_id_seq() | ||||
| 
 | ||||
|             self.progress.done() | ||||
|         except: | ||||
|             global end_error_exec_info | ||||
|  | @ -707,6 +712,16 @@ class Porter(object): | |||
| 
 | ||||
|         defer.returnValue((done, remaining + done)) | ||||
| 
 | ||||
|     def _setup_state_group_id_seq(self): | ||||
|         def r(txn): | ||||
|             txn.execute("SELECT MAX(id) FROM state_groups") | ||||
|             next_id = txn.fetchone()[0]+1 | ||||
|             txn.execute( | ||||
|                 "ALTER SEQUENCE state_group_id_seq RESTART WITH %s", | ||||
|                 (next_id,), | ||||
|             ) | ||||
|         return self.postgres_store.runInteraction("setup_state_group_id_seq", r) | ||||
| 
 | ||||
| 
 | ||||
| ############################################## | ||||
| ###### The following is simply UI stuff ###### | ||||
|  |  | |||
|  | @ -15,9 +15,10 @@ | |||
| 
 | ||||
| """Contains exceptions and error codes.""" | ||||
| 
 | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| import simplejson as json | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -48,6 +48,7 @@ from synapse.server import HomeServer | |||
| from synapse.storage import are_all_users_on_domain | ||||
| from synapse.storage.engines import IncorrectDatabaseSetup, create_engine | ||||
| from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database | ||||
| from synapse.util.caches import CACHE_SIZE_FACTOR | ||||
| from synapse.util.httpresourcetree import create_resource_tree | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.manhole import manhole | ||||
|  | @ -402,6 +403,10 @@ def run(hs): | |||
| 
 | ||||
|     stats = {} | ||||
| 
 | ||||
|     # Contains the list of processes we will be monitoring | ||||
|     # currently either 0 or 1 | ||||
|     stats_process = [] | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def phone_stats_home(): | ||||
|         logger.info("Gathering stats for reporting") | ||||
|  | @ -425,8 +430,21 @@ def run(hs): | |||
|         stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms() | ||||
|         stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() | ||||
| 
 | ||||
|         r30_results = yield hs.get_datastore().count_r30_users() | ||||
|         for name, count in r30_results.iteritems(): | ||||
|             stats["r30_users_" + name] = count | ||||
| 
 | ||||
|         daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() | ||||
|         stats["daily_sent_messages"] = daily_sent_messages | ||||
|         stats["cache_factor"] = CACHE_SIZE_FACTOR | ||||
|         stats["event_cache_size"] = hs.config.event_cache_size | ||||
| 
 | ||||
|         if len(stats_process) > 0: | ||||
|             stats["memory_rss"] = 0 | ||||
|             stats["cpu_average"] = 0 | ||||
|             for process in stats_process: | ||||
|                 stats["memory_rss"] += process.memory_info().rss | ||||
|                 stats["cpu_average"] += int(process.cpu_percent(interval=None)) | ||||
| 
 | ||||
|         logger.info("Reporting stats to matrix.org: %s" % (stats,)) | ||||
|         try: | ||||
|  | @ -437,10 +455,32 @@ def run(hs): | |||
|         except Exception as e: | ||||
|             logger.warn("Error reporting stats: %s", e) | ||||
| 
 | ||||
|     def performance_stats_init(): | ||||
|         try: | ||||
|             import psutil | ||||
|             process = psutil.Process() | ||||
|             # Ensure we can fetch both, and make the initial request for cpu_percent | ||||
|             # so the next request will use this as the initial point. | ||||
|             process.memory_info().rss | ||||
|             process.cpu_percent(interval=None) | ||||
|             logger.info("report_stats can use psutil") | ||||
|             stats_process.append(process) | ||||
|         except (ImportError, AttributeError): | ||||
|             logger.warn( | ||||
|                 "report_stats enabled but psutil is not installed or incorrect version." | ||||
|                 " Disabling reporting of memory/cpu stats." | ||||
|                 " Ensuring psutil is available will help matrix.org track performance" | ||||
|                 " changes across releases." | ||||
|             ) | ||||
| 
 | ||||
|     if hs.config.report_stats: | ||||
|         logger.info("Scheduling stats reporting for 3 hour intervals") | ||||
|         clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000) | ||||
| 
 | ||||
|         # We need to defer this init for the cases that we daemonize | ||||
|         # otherwise the process ID we get is that of the non-daemon process | ||||
|         clock.call_later(0, performance_stats_init) | ||||
| 
 | ||||
|         # We wait 5 minutes to send the first set of stats as the server can | ||||
|         # be quite busy the first few minutes | ||||
|         clock.call_later(5 * 60, phone_stats_home) | ||||
|  |  | |||
|  | @ -38,7 +38,7 @@ def pid_running(pid): | |||
|     try: | ||||
|         os.kill(pid, 0) | ||||
|         return True | ||||
|     except OSError, err: | ||||
|     except OSError as err: | ||||
|         if err.errno == errno.EPERM: | ||||
|             return True | ||||
|         return False | ||||
|  | @ -98,7 +98,7 @@ def stop(pidfile, app): | |||
|         try: | ||||
|             os.kill(pid, signal.SIGTERM) | ||||
|             write("stopped %s" % (app,), colour=GREEN) | ||||
|         except OSError, err: | ||||
|         except OSError as err: | ||||
|             if err.errno == errno.ESRCH: | ||||
|                 write("%s not running" % (app,), colour=YELLOW) | ||||
|             elif err.errno == errno.EPERM: | ||||
|  | @ -252,6 +252,7 @@ def main(): | |||
|             for running_pid in running_pids: | ||||
|                 while pid_running(running_pid): | ||||
|                     time.sleep(0.2) | ||||
|             write("All processes exited; now restarting...") | ||||
| 
 | ||||
|     if action == "start" or action == "restart": | ||||
|         if start_stop_synapse: | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -613,6 +614,19 @@ class TransportLayerClient(object): | |||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|     @log_function | ||||
|     def join_group(self, destination, group_id, user_id, content): | ||||
|         """Attempts to join a group | ||||
|         """ | ||||
|         path = PREFIX + "/groups/%s/users/%s/join" % (group_id, user_id) | ||||
| 
 | ||||
|         return self.client.post_json( | ||||
|             destination=destination, | ||||
|             path=path, | ||||
|             data=content, | ||||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|     @log_function | ||||
|     def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): | ||||
|         """Invite a user to a group | ||||
|  | @ -856,6 +870,21 @@ class TransportLayerClient(object): | |||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|     @log_function | ||||
|     def set_group_join_policy(self, destination, group_id, requester_user_id, | ||||
|                               content): | ||||
|         """Sets the join policy for a group | ||||
|         """ | ||||
|         path = PREFIX + "/groups/%s/settings/m.join_policy" % (group_id,) | ||||
| 
 | ||||
|         return self.client.put_json( | ||||
|             destination=destination, | ||||
|             path=path, | ||||
|             args={"requester_user_id": requester_user_id}, | ||||
|             data=content, | ||||
|             ignore_backoff=True, | ||||
|         ) | ||||
| 
 | ||||
|     @log_function | ||||
|     def delete_group_summary_user(self, destination, group_id, requester_user_id, | ||||
|                                   user_id, role_id): | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -802,6 +803,23 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): | |||
|         defer.returnValue((200, new_content)) | ||||
| 
 | ||||
| 
 | ||||
| class FederationGroupsJoinServlet(BaseFederationServlet): | ||||
|     """Attempt to join a group | ||||
|     """ | ||||
|     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$" | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, origin, content, query, group_id, user_id): | ||||
|         if get_domain_from_id(user_id) != origin: | ||||
|             raise SynapseError(403, "user_id doesn't match origin") | ||||
| 
 | ||||
|         new_content = yield self.handler.join_group( | ||||
|             group_id, user_id, content, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, new_content)) | ||||
| 
 | ||||
| 
 | ||||
| class FederationGroupsRemoveUserServlet(BaseFederationServlet): | ||||
|     """Leave or kick a user from the group | ||||
|     """ | ||||
|  | @ -1124,6 +1142,24 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): | |||
|         defer.returnValue((200, resp)) | ||||
| 
 | ||||
| 
 | ||||
| class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): | ||||
|     """Sets whether a group is joinable without an invite or knock | ||||
|     """ | ||||
|     PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$" | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT(self, origin, content, query, group_id): | ||||
|         requester_user_id = parse_string_from_args(query, "requester_user_id") | ||||
|         if get_domain_from_id(requester_user_id) != origin: | ||||
|             raise SynapseError(403, "requester_user_id doesn't match origin") | ||||
| 
 | ||||
|         new_content = yield self.handler.set_group_join_policy( | ||||
|             group_id, requester_user_id, content | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, new_content)) | ||||
| 
 | ||||
| 
 | ||||
| FEDERATION_SERVLET_CLASSES = ( | ||||
|     FederationSendServlet, | ||||
|     FederationPullServlet, | ||||
|  | @ -1163,6 +1199,7 @@ GROUP_SERVER_SERVLET_CLASSES = ( | |||
|     FederationGroupsInvitedUsersServlet, | ||||
|     FederationGroupsInviteServlet, | ||||
|     FederationGroupsAcceptInviteServlet, | ||||
|     FederationGroupsJoinServlet, | ||||
|     FederationGroupsRemoveUserServlet, | ||||
|     FederationGroupsSummaryRoomsServlet, | ||||
|     FederationGroupsCategoriesServlet, | ||||
|  | @ -1172,6 +1209,7 @@ GROUP_SERVER_SERVLET_CLASSES = ( | |||
|     FederationGroupsSummaryUsersServlet, | ||||
|     FederationGroupsAddRoomsServlet, | ||||
|     FederationGroupsAddRoomsConfigServlet, | ||||
|     FederationGroupsSettingJoinPolicyServlet, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -205,6 +206,28 @@ class GroupsServerHandler(object): | |||
| 
 | ||||
|         defer.returnValue({}) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def set_group_join_policy(self, group_id, requester_user_id, content): | ||||
|         """Sets the group join policy. | ||||
| 
 | ||||
|         Currently supported policies are: | ||||
|          - "invite": an invite must be received and accepted in order to join. | ||||
|          - "open": anyone can join. | ||||
|         """ | ||||
|         yield self.check_group_is_ours( | ||||
|             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | ||||
|         ) | ||||
| 
 | ||||
|         join_policy = _parse_join_policy_from_contents(content) | ||||
|         if join_policy is None: | ||||
|             raise SynapseError( | ||||
|                 400, "No value specified for 'm.join_policy'" | ||||
|             ) | ||||
| 
 | ||||
|         yield self.store.set_group_join_policy(group_id, join_policy=join_policy) | ||||
| 
 | ||||
|         defer.returnValue({}) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_group_categories(self, group_id, requester_user_id): | ||||
|         """Get all categories in a group (as seen by user) | ||||
|  | @ -381,9 +404,16 @@ class GroupsServerHandler(object): | |||
| 
 | ||||
|         yield self.check_group_is_ours(group_id, requester_user_id) | ||||
| 
 | ||||
|         group_description = yield self.store.get_group(group_id) | ||||
|         group = yield self.store.get_group(group_id) | ||||
| 
 | ||||
|         if group: | ||||
|             cols = [ | ||||
|                 "name", "short_description", "long_description", | ||||
|                 "avatar_url", "is_public", | ||||
|             ] | ||||
|             group_description = {key: group[key] for key in cols} | ||||
|             group_description["is_openly_joinable"] = group["join_policy"] == "open" | ||||
| 
 | ||||
|         if group_description: | ||||
|             defer.returnValue(group_description) | ||||
|         else: | ||||
|             raise SynapseError(404, "Unknown group") | ||||
|  | @ -654,6 +684,40 @@ class GroupsServerHandler(object): | |||
|         else: | ||||
|             raise SynapseError(502, "Unknown state returned by HS") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _add_user(self, group_id, user_id, content): | ||||
|         """Add a user to a group based on a content dict. | ||||
| 
 | ||||
|         See accept_invite, join_group. | ||||
|         """ | ||||
|         if not self.hs.is_mine_id(user_id): | ||||
|             local_attestation = self.attestations.create_attestation( | ||||
|                 group_id, user_id, | ||||
|             ) | ||||
| 
 | ||||
|             remote_attestation = content["attestation"] | ||||
| 
 | ||||
|             yield self.attestations.verify_attestation( | ||||
|                 remote_attestation, | ||||
|                 user_id=user_id, | ||||
|                 group_id=group_id, | ||||
|             ) | ||||
|         else: | ||||
|             local_attestation = None | ||||
|             remote_attestation = None | ||||
| 
 | ||||
|         is_public = _parse_visibility_from_contents(content) | ||||
| 
 | ||||
|         yield self.store.add_user_to_group( | ||||
|             group_id, user_id, | ||||
|             is_admin=False, | ||||
|             is_public=is_public, | ||||
|             local_attestation=local_attestation, | ||||
|             remote_attestation=remote_attestation, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(local_attestation) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def accept_invite(self, group_id, requester_user_id, content): | ||||
|         """User tries to accept an invite to the group. | ||||
|  | @ -670,30 +734,27 @@ class GroupsServerHandler(object): | |||
|         if not is_invited: | ||||
|             raise SynapseError(403, "User not invited to group") | ||||
| 
 | ||||
|         if not self.hs.is_mine_id(requester_user_id): | ||||
|             local_attestation = self.attestations.create_attestation( | ||||
|                 group_id, requester_user_id, | ||||
|             ) | ||||
|             remote_attestation = content["attestation"] | ||||
|         local_attestation = yield self._add_user(group_id, requester_user_id, content) | ||||
| 
 | ||||
|             yield self.attestations.verify_attestation( | ||||
|                 remote_attestation, | ||||
|                 user_id=requester_user_id, | ||||
|                 group_id=group_id, | ||||
|             ) | ||||
|         else: | ||||
|             local_attestation = None | ||||
|             remote_attestation = None | ||||
|         defer.returnValue({ | ||||
|             "state": "join", | ||||
|             "attestation": local_attestation, | ||||
|         }) | ||||
| 
 | ||||
|         is_public = _parse_visibility_from_contents(content) | ||||
|     @defer.inlineCallbacks | ||||
|     def join_group(self, group_id, requester_user_id, content): | ||||
|         """User tries to join the group. | ||||
| 
 | ||||
|         yield self.store.add_user_to_group( | ||||
|             group_id, requester_user_id, | ||||
|             is_admin=False, | ||||
|             is_public=is_public, | ||||
|             local_attestation=local_attestation, | ||||
|             remote_attestation=remote_attestation, | ||||
|         This will error if the group requires an invite/knock to join | ||||
|         """ | ||||
| 
 | ||||
|         group_info = yield self.check_group_is_ours( | ||||
|             group_id, requester_user_id, and_exists=True | ||||
|         ) | ||||
|         if group_info['join_policy'] != "open": | ||||
|             raise SynapseError(403, "Group is not publicly joinable") | ||||
| 
 | ||||
|         local_attestation = yield self._add_user(group_id, requester_user_id, content) | ||||
| 
 | ||||
|         defer.returnValue({ | ||||
|             "state": "join", | ||||
|  | @ -835,6 +896,31 @@ class GroupsServerHandler(object): | |||
|         }) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_join_policy_from_contents(content): | ||||
|     """Given a content for a request, return the specified join policy or None | ||||
|     """ | ||||
| 
 | ||||
|     join_policy_dict = content.get("m.join_policy") | ||||
|     if join_policy_dict: | ||||
|         return _parse_join_policy_dict(join_policy_dict) | ||||
|     else: | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| def _parse_join_policy_dict(join_policy_dict): | ||||
|     """Given a dict for the "m.join_policy" config return the join policy specified | ||||
|     """ | ||||
|     join_policy_type = join_policy_dict.get("type") | ||||
|     if not join_policy_type: | ||||
|         return "invite" | ||||
| 
 | ||||
|     if join_policy_type not in ("invite", "open"): | ||||
|         raise SynapseError( | ||||
|             400, "Synapse only supports 'invite'/'open' join rule" | ||||
|         ) | ||||
|     return join_policy_type | ||||
| 
 | ||||
| 
 | ||||
| def _parse_visibility_from_contents(content): | ||||
|     """Given a content for a request parse out whether the entity should be | ||||
|     public or not | ||||
|  |  | |||
|  | @ -155,7 +155,7 @@ class DeviceHandler(BaseHandler): | |||
| 
 | ||||
|         try: | ||||
|             yield self.store.delete_device(user_id, device_id) | ||||
|         except errors.StoreError, e: | ||||
|         except errors.StoreError as e: | ||||
|             if e.code == 404: | ||||
|                 # no match | ||||
|                 pass | ||||
|  | @ -204,7 +204,7 @@ class DeviceHandler(BaseHandler): | |||
| 
 | ||||
|         try: | ||||
|             yield self.store.delete_devices(user_id, device_ids) | ||||
|         except errors.StoreError, e: | ||||
|         except errors.StoreError as e: | ||||
|             if e.code == 404: | ||||
|                 # no match | ||||
|                 pass | ||||
|  | @ -243,7 +243,7 @@ class DeviceHandler(BaseHandler): | |||
|                 new_display_name=content.get("display_name") | ||||
|             ) | ||||
|             yield self.notify_device_update(user_id, [device_id]) | ||||
|         except errors.StoreError, e: | ||||
|         except errors.StoreError as e: | ||||
|             if e.code == 404: | ||||
|                 raise errors.NotFoundError() | ||||
|             else: | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -134,23 +135,8 @@ class E2eKeysHandler(object): | |||
|                     if user_id in destination_query: | ||||
|                         results[user_id] = keys | ||||
| 
 | ||||
|             except CodeMessageException as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": e.code, "message": e.message | ||||
|                 } | ||||
|             except NotRetryingDestination as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": 503, "message": "Not ready for retry", | ||||
|                 } | ||||
|             except FederationDeniedError as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": 403, "message": "Federation Denied", | ||||
|                 } | ||||
|             except Exception as e: | ||||
|                 # include ConnectionRefused and other errors | ||||
|                 failures[destination] = { | ||||
|                     "status": 503, "message": e.message | ||||
|                 } | ||||
|                 failures[destination] = _exception_to_failure(e) | ||||
| 
 | ||||
|         yield make_deferred_yieldable(defer.gatherResults([ | ||||
|             preserve_fn(do_remote_query)(destination) | ||||
|  | @ -252,19 +238,8 @@ class E2eKeysHandler(object): | |||
|                 for user_id, keys in remote_result["one_time_keys"].items(): | ||||
|                     if user_id in device_keys: | ||||
|                         json_result[user_id] = keys | ||||
|             except CodeMessageException as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": e.code, "message": e.message | ||||
|                 } | ||||
|             except NotRetryingDestination as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": 503, "message": "Not ready for retry", | ||||
|                 } | ||||
|             except Exception as e: | ||||
|                 # include ConnectionRefused and other errors | ||||
|                 failures[destination] = { | ||||
|                     "status": 503, "message": e.message | ||||
|                 } | ||||
|                 failures[destination] = _exception_to_failure(e) | ||||
| 
 | ||||
|         yield make_deferred_yieldable(defer.gatherResults([ | ||||
|             preserve_fn(claim_client_keys)(destination) | ||||
|  | @ -362,6 +337,31 @@ class E2eKeysHandler(object): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def _exception_to_failure(e): | ||||
|     if isinstance(e, CodeMessageException): | ||||
|         return { | ||||
|             "status": e.code, "message": e.message, | ||||
|         } | ||||
| 
 | ||||
|     if isinstance(e, NotRetryingDestination): | ||||
|         return { | ||||
|             "status": 503, "message": "Not ready for retry", | ||||
|         } | ||||
| 
 | ||||
|     if isinstance(e, FederationDeniedError): | ||||
|         return { | ||||
|             "status": 403, "message": "Federation Denied", | ||||
|         } | ||||
| 
 | ||||
|     # include ConnectionRefused and other errors | ||||
|     # | ||||
|     # Note that some Exceptions (notably twisted's ResponseFailed etc) don't | ||||
|     # give a string for e.message, which simplejson then fails to serialize. | ||||
|     return { | ||||
|         "status": 503, "message": str(e.message), | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def _one_time_keys_match(old_key_json, new_key): | ||||
|     old_key = json.loads(old_key_json) | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -90,6 +91,8 @@ class GroupsLocalHandler(object): | |||
|     get_group_role = _create_rerouter("get_group_role") | ||||
|     get_group_roles = _create_rerouter("get_group_roles") | ||||
| 
 | ||||
|     set_group_join_policy = _create_rerouter("set_group_join_policy") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_group_summary(self, group_id, requester_user_id): | ||||
|         """Get the group summary for a group. | ||||
|  | @ -226,7 +229,45 @@ class GroupsLocalHandler(object): | |||
|     def join_group(self, group_id, user_id, content): | ||||
|         """Request to join a group | ||||
|         """ | ||||
|         raise NotImplementedError()  # TODO | ||||
|         if self.is_mine_id(group_id): | ||||
|             yield self.groups_server_handler.join_group( | ||||
|                 group_id, user_id, content | ||||
|             ) | ||||
|             local_attestation = None | ||||
|             remote_attestation = None | ||||
|         else: | ||||
|             local_attestation = self.attestations.create_attestation(group_id, user_id) | ||||
|             content["attestation"] = local_attestation | ||||
| 
 | ||||
|             res = yield self.transport_client.join_group( | ||||
|                 get_domain_from_id(group_id), group_id, user_id, content, | ||||
|             ) | ||||
| 
 | ||||
|             remote_attestation = res["attestation"] | ||||
| 
 | ||||
|             yield self.attestations.verify_attestation( | ||||
|                 remote_attestation, | ||||
|                 group_id=group_id, | ||||
|                 user_id=user_id, | ||||
|                 server_name=get_domain_from_id(group_id), | ||||
|             ) | ||||
| 
 | ||||
|         # TODO: Check that the group is public and we're being added publically | ||||
|         is_publicised = content.get("publicise", False) | ||||
| 
 | ||||
|         token = yield self.store.register_user_group_membership( | ||||
|             group_id, user_id, | ||||
|             membership="join", | ||||
|             is_admin=False, | ||||
|             local_attestation=local_attestation, | ||||
|             remote_attestation=remote_attestation, | ||||
|             is_publicised=is_publicised, | ||||
|         ) | ||||
|         self.notifier.on_new_event( | ||||
|             "groups_key", token, users=[user_id], | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue({}) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def accept_invite(self, group_id, user_id, content): | ||||
|  |  | |||
|  | @ -15,6 +15,11 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| """Utilities for interacting with Identity Servers""" | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| import simplejson as json | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import ( | ||||
|  | @ -24,9 +29,6 @@ from ._base import BaseHandler | |||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.api.errors import SynapseError, Codes | ||||
| 
 | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ from synapse.types import ( | |||
| from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter | ||||
| from synapse.util.logcontext import preserve_fn, run_in_background | ||||
| from synapse.util.metrics import measure_func | ||||
| from synapse.util.frozenutils import unfreeze | ||||
| from synapse.util.frozenutils import frozendict_json_encoder | ||||
| from synapse.util.stringutils import random_string | ||||
| from synapse.visibility import filter_events_for_client | ||||
| from synapse.replication.http.send_event import send_event_to_master | ||||
|  | @ -678,7 +678,7 @@ class EventCreationHandler(object): | |||
| 
 | ||||
|         # Ensure that we can round trip before trying to persist in db | ||||
|         try: | ||||
|             dump = simplejson.dumps(unfreeze(event.content)) | ||||
|             dump = frozendict_json_encoder.encode(event.content) | ||||
|             simplejson.loads(dump) | ||||
|         except Exception: | ||||
|             logger.exception("Failed to encode content: %r", event.content) | ||||
|  |  | |||
|  | @ -286,7 +286,8 @@ class MatrixFederationHttpClient(object): | |||
|         headers_dict[b"Authorization"] = auth_headers | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def put_json(self, destination, path, data={}, json_data_callback=None, | ||||
|     def put_json(self, destination, path, args={}, data={}, | ||||
|                  json_data_callback=None, | ||||
|                  long_retries=False, timeout=None, | ||||
|                  ignore_backoff=False, | ||||
|                  backoff_on_404=False): | ||||
|  | @ -296,6 +297,7 @@ class MatrixFederationHttpClient(object): | |||
|             destination (str): The remote server to send the HTTP request | ||||
|                 to. | ||||
|             path (str): The HTTP path. | ||||
|             args (dict): query params | ||||
|             data (dict): A dict containing the data that will be used as | ||||
|                 the request body. This will be encoded as JSON. | ||||
|             json_data_callback (callable): A callable returning the dict to | ||||
|  | @ -342,6 +344,7 @@ class MatrixFederationHttpClient(object): | |||
|             path, | ||||
|             body_callback=body_callback, | ||||
|             headers_dict={"Content-Type": ["application/json"]}, | ||||
|             query_bytes=encode_query_args(args), | ||||
|             long_retries=long_retries, | ||||
|             timeout=timeout, | ||||
|             ignore_backoff=ignore_backoff, | ||||
|  | @ -373,6 +376,7 @@ class MatrixFederationHttpClient(object): | |||
|                 giving up. None indicates no timeout. | ||||
|             ignore_backoff (bool): true to ignore the historical backoff data and | ||||
|                 try the request anyway. | ||||
|             args (dict): query params | ||||
|         Returns: | ||||
|             Deferred: Succeeds when we get a 2xx HTTP response. The result | ||||
|             will be the decoded JSON body. | ||||
|  |  | |||
|  | @ -113,6 +113,11 @@ response_db_sched_duration = metrics.register_counter( | |||
|     "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] | ||||
| ) | ||||
| 
 | ||||
| # size in bytes of the response written | ||||
| response_size = metrics.register_counter( | ||||
|     "response_size", labels=["method", "servlet", "tag"] | ||||
| ) | ||||
| 
 | ||||
| _next_request_id = 0 | ||||
| 
 | ||||
| 
 | ||||
|  | @ -426,6 +431,8 @@ class RequestMetrics(object): | |||
|             context.db_sched_duration_ms / 1000., request.method, self.name, tag | ||||
|         ) | ||||
| 
 | ||||
|         response_size.inc_by(request.sentLength, request.method, self.name, tag) | ||||
| 
 | ||||
| 
 | ||||
| class RootRedirect(resource.Resource): | ||||
|     """Redirects the root '/' path to another path.""" | ||||
|  |  | |||
|  | @ -24,6 +24,8 @@ import simplejson | |||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| _json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False) | ||||
| 
 | ||||
| 
 | ||||
| class Command(object): | ||||
|     """The base command class. | ||||
|  | @ -107,7 +109,7 @@ class RdataCommand(Command): | |||
|         return " ".join(( | ||||
|             self.stream_name, | ||||
|             str(self.token) if self.token is not None else "batch", | ||||
|             simplejson.dumps(self.row, namedtuple_as_object=False), | ||||
|             _json_encoder.encode(self.row), | ||||
|         )) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -302,7 +304,7 @@ class InvalidateCacheCommand(Command): | |||
| 
 | ||||
|     def to_line(self): | ||||
|         return " ".join(( | ||||
|             self.cache_func, simplejson.dumps(self.keys, namedtuple_as_object=False) | ||||
|             self.cache_func, _json_encoder.encode(self.keys), | ||||
|         )) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -334,7 +336,7 @@ class UserIpCommand(Command): | |||
|         ) | ||||
| 
 | ||||
|     def to_line(self): | ||||
|         return self.user_id + " " + simplejson.dumps(( | ||||
|         return self.user_id + " " + _json_encoder.encode(( | ||||
|             self.access_token, self.ip, self.user_agent, self.device_id, | ||||
|             self.last_seen, | ||||
|         )) | ||||
|  |  | |||
|  | @ -655,7 +655,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): | |||
|             content=event_content, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, {})) | ||||
|         return_value = {} | ||||
| 
 | ||||
|         if membership_action == "join": | ||||
|             return_value["room_id"] = room_id | ||||
| 
 | ||||
|         defer.returnValue((200, return_value)) | ||||
| 
 | ||||
|     def _has_3pid_invite_keys(self, content): | ||||
|         for key in {"id_server", "medium", "address"}: | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -401,6 +402,32 @@ class GroupInvitedUsersServlet(RestServlet): | |||
|         defer.returnValue((200, result)) | ||||
| 
 | ||||
| 
 | ||||
| class GroupSettingJoinPolicyServlet(RestServlet): | ||||
|     """Set group join policy | ||||
|     """ | ||||
|     PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(GroupSettingJoinPolicyServlet, self).__init__() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.groups_handler = hs.get_groups_local_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT(self, request, group_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         requester_user_id = requester.user.to_string() | ||||
| 
 | ||||
|         content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         result = yield self.groups_handler.set_group_join_policy( | ||||
|             group_id, | ||||
|             requester_user_id, | ||||
|             content, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, result)) | ||||
| 
 | ||||
| 
 | ||||
| class GroupCreateServlet(RestServlet): | ||||
|     """Create a group | ||||
|     """ | ||||
|  | @ -738,6 +765,7 @@ def register_servlets(hs, http_server): | |||
|     GroupInvitedUsersServlet(hs).register(http_server) | ||||
|     GroupUsersServlet(hs).register(http_server) | ||||
|     GroupRoomServlet(hs).register(http_server) | ||||
|     GroupSettingJoinPolicyServlet(hs).register(http_server) | ||||
|     GroupCreateServlet(hs).register(http_server) | ||||
|     GroupAdminRoomsServlet(hs).register(http_server) | ||||
|     GroupAdminRoomsConfigServlet(hs).register(http_server) | ||||
|  |  | |||
|  | @ -483,33 +483,34 @@ class StateResolutionHandler(object): | |||
|                     key: e_ids.pop() for key, e_ids in state.iteritems() | ||||
|                 } | ||||
| 
 | ||||
|             # if the new state matches any of the input state groups, we can | ||||
|             # use that state group again. Otherwise we will generate a state_id | ||||
|             # which will be used as a cache key for future resolutions, but | ||||
|             # not get persisted. | ||||
|             state_group = None | ||||
|             new_state_event_ids = frozenset(new_state.itervalues()) | ||||
|             for sg, events in state_groups_ids.iteritems(): | ||||
|                 if new_state_event_ids == frozenset(e_id for e_id in events): | ||||
|                     state_group = sg | ||||
|                     break | ||||
|             with Measure(self.clock, "state.create_group_ids"): | ||||
|                 # if the new state matches any of the input state groups, we can | ||||
|                 # use that state group again. Otherwise we will generate a state_id | ||||
|                 # which will be used as a cache key for future resolutions, but | ||||
|                 # not get persisted. | ||||
|                 state_group = None | ||||
|                 new_state_event_ids = frozenset(new_state.itervalues()) | ||||
|                 for sg, events in state_groups_ids.iteritems(): | ||||
|                     if new_state_event_ids == frozenset(e_id for e_id in events): | ||||
|                         state_group = sg | ||||
|                         break | ||||
| 
 | ||||
|             # TODO: We want to create a state group for this set of events, to | ||||
|             # increase cache hits, but we need to make sure that it doesn't | ||||
|             # end up as a prev_group without being added to the database | ||||
|                 # TODO: We want to create a state group for this set of events, to | ||||
|                 # increase cache hits, but we need to make sure that it doesn't | ||||
|                 # end up as a prev_group without being added to the database | ||||
| 
 | ||||
|             prev_group = None | ||||
|             delta_ids = None | ||||
|             for old_group, old_ids in state_groups_ids.iteritems(): | ||||
|                 if not set(new_state) - set(old_ids): | ||||
|                     n_delta_ids = { | ||||
|                         k: v | ||||
|                         for k, v in new_state.iteritems() | ||||
|                         if old_ids.get(k) != v | ||||
|                     } | ||||
|                     if not delta_ids or len(n_delta_ids) < len(delta_ids): | ||||
|                         prev_group = old_group | ||||
|                         delta_ids = n_delta_ids | ||||
|                 prev_group = None | ||||
|                 delta_ids = None | ||||
|                 for old_group, old_ids in state_groups_ids.iteritems(): | ||||
|                     if not set(new_state) - set(old_ids): | ||||
|                         n_delta_ids = { | ||||
|                             k: v | ||||
|                             for k, v in new_state.iteritems() | ||||
|                             if old_ids.get(k) != v | ||||
|                         } | ||||
|                         if not delta_ids or len(n_delta_ids) < len(delta_ids): | ||||
|                             prev_group = old_group | ||||
|                             delta_ids = n_delta_ids | ||||
| 
 | ||||
|             cache = _StateCacheEntry( | ||||
|                 state=new_state, | ||||
|  |  | |||
|  | @ -14,8 +14,6 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.storage.devices import DeviceStore | ||||
| from .appservice import ( | ||||
|     ApplicationServiceStore, ApplicationServiceTransactionStore | ||||
|  | @ -244,13 +242,12 @@ class DataStore(RoomMemberStore, RoomStore, | |||
| 
 | ||||
|         return [UserPresenceState(**row) for row in rows] | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def count_daily_users(self): | ||||
|         """ | ||||
|         Counts the number of users who used this homeserver in the last 24 hours. | ||||
|         """ | ||||
|         def _count_users(txn): | ||||
|             yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24), | ||||
|             yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) | ||||
| 
 | ||||
|             sql = """ | ||||
|                 SELECT COALESCE(count(*), 0) FROM ( | ||||
|  | @ -264,8 +261,91 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|             count, = txn.fetchone() | ||||
|             return count | ||||
| 
 | ||||
|         ret = yield self.runInteraction("count_users", _count_users) | ||||
|         defer.returnValue(ret) | ||||
|         return self.runInteraction("count_users", _count_users) | ||||
| 
 | ||||
|     def count_r30_users(self): | ||||
|         """ | ||||
|         Counts the number of 30 day retained users, defined as:- | ||||
|          * Users who have created their accounts more than 30 days | ||||
|          * Where last seen at most 30 days ago | ||||
|          * Where account creation and last_seen are > 30 days | ||||
| 
 | ||||
|          Returns counts globaly for a given user as well as breaking | ||||
|          by platform | ||||
|         """ | ||||
|         def _count_r30_users(txn): | ||||
|             thirty_days_in_secs = 86400 * 30 | ||||
|             now = int(self._clock.time_msec()) | ||||
|             thirty_days_ago_in_secs = now - thirty_days_in_secs | ||||
| 
 | ||||
|             sql = """ | ||||
|                 SELECT platform, COALESCE(count(*), 0) FROM ( | ||||
|                      SELECT | ||||
|                         users.name, platform, users.creation_ts * 1000, | ||||
|                         MAX(uip.last_seen) | ||||
|                      FROM users | ||||
|                      INNER JOIN ( | ||||
|                          SELECT | ||||
|                          user_id, | ||||
|                          last_seen, | ||||
|                          CASE | ||||
|                              WHEN user_agent LIKE '%Android%' THEN 'android' | ||||
|                              WHEN user_agent LIKE '%iOS%' THEN 'ios' | ||||
|                              WHEN user_agent LIKE '%Electron%' THEN 'electron' | ||||
|                              WHEN user_agent LIKE '%Mozilla%' THEN 'web' | ||||
|                              WHEN user_agent LIKE '%Gecko%' THEN 'web' | ||||
|                              ELSE 'unknown' | ||||
|                          END | ||||
|                          AS platform | ||||
|                          FROM user_ips | ||||
|                      ) uip | ||||
|                      ON users.name = uip.user_id | ||||
|                      AND users.appservice_id is NULL | ||||
|                      AND users.creation_ts < ? | ||||
|                      AND uip.last_seen/1000 > ? | ||||
|                      AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 | ||||
|                      GROUP BY users.name, platform, users.creation_ts | ||||
|                 ) u GROUP BY platform | ||||
|             """ | ||||
| 
 | ||||
|             results = {} | ||||
|             txn.execute(sql, (thirty_days_ago_in_secs, | ||||
|                               thirty_days_ago_in_secs)) | ||||
| 
 | ||||
|             for row in txn: | ||||
|                 if row[0] is 'unknown': | ||||
|                     pass | ||||
|                 results[row[0]] = row[1] | ||||
| 
 | ||||
|             sql = """ | ||||
|                 SELECT COALESCE(count(*), 0) FROM ( | ||||
|                     SELECT users.name, users.creation_ts * 1000, | ||||
|                                                         MAX(uip.last_seen) | ||||
|                     FROM users | ||||
|                     INNER JOIN ( | ||||
|                         SELECT | ||||
|                         user_id, | ||||
|                         last_seen | ||||
|                         FROM user_ips | ||||
|                     ) uip | ||||
|                     ON users.name = uip.user_id | ||||
|                     AND appservice_id is NULL | ||||
|                     AND users.creation_ts < ? | ||||
|                     AND uip.last_seen/1000 > ? | ||||
|                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 | ||||
|                     GROUP BY users.name, users.creation_ts | ||||
|                 ) u | ||||
|             """ | ||||
| 
 | ||||
|             txn.execute(sql, (thirty_days_ago_in_secs, | ||||
|                               thirty_days_ago_in_secs)) | ||||
| 
 | ||||
|             count, = txn.fetchone() | ||||
|             results['all'] = count | ||||
| 
 | ||||
|             return results | ||||
| 
 | ||||
|         return self.runInteraction("count_r30_users", _count_r30_users) | ||||
| 
 | ||||
|     def get_users(self): | ||||
|         """Function to reterive a list of users in users table. | ||||
|  |  | |||
|  | @ -48,6 +48,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): | |||
|             columns=["user_id", "device_id", "last_seen"], | ||||
|         ) | ||||
| 
 | ||||
|         self.register_background_index_update( | ||||
|             "user_ips_last_seen_index", | ||||
|             index_name="user_ips_last_seen", | ||||
|             table="user_ips", | ||||
|             columns=["user_id", "last_seen"], | ||||
|         ) | ||||
| 
 | ||||
|         # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) | ||||
|         self._batch_row_update = {} | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,15 +14,19 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from synapse.storage.events_worker import EventsWorkerStore | ||||
| from collections import OrderedDict, deque, namedtuple | ||||
| from functools import wraps | ||||
| import logging | ||||
| 
 | ||||
| import simplejson as json | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.events import USE_FROZEN_DICTS | ||||
| 
 | ||||
| from synapse.storage.events_worker import EventsWorkerStore | ||||
| from synapse.util.async import ObservableDeferred | ||||
| from synapse.util.frozenutils import frozendict_json_encoder | ||||
| from synapse.util.logcontext import ( | ||||
|     PreserveLoggingContext, make_deferred_yieldable | ||||
|     PreserveLoggingContext, make_deferred_yieldable, | ||||
| ) | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.metrics import Measure | ||||
|  | @ -30,16 +34,8 @@ from synapse.api.constants import EventTypes | |||
| from synapse.api.errors import SynapseError | ||||
| from synapse.util.caches.descriptors import cached, cachedInlineCallbacks | ||||
| from synapse.types import get_domain_from_id | ||||
| 
 | ||||
| from canonicaljson import encode_canonical_json | ||||
| from collections import deque, namedtuple, OrderedDict | ||||
| from functools import wraps | ||||
| 
 | ||||
| import synapse.metrics | ||||
| 
 | ||||
| import logging | ||||
| import simplejson as json | ||||
| 
 | ||||
| # these are only included to make the type annotations work | ||||
| from synapse.events import EventBase    # noqa: F401 | ||||
| from synapse.events.snapshot import EventContext   # noqa: F401 | ||||
|  | @ -71,10 +67,7 @@ state_delta_reuse_delta_counter = metrics.register_counter( | |||
| 
 | ||||
| 
 | ||||
| def encode_json(json_object): | ||||
|     if USE_FROZEN_DICTS: | ||||
|         return encode_canonical_json(json_object) | ||||
|     else: | ||||
|         return json.dumps(json_object, ensure_ascii=False) | ||||
|     return frozendict_json_encoder.encode(json_object) | ||||
| 
 | ||||
| 
 | ||||
| class _EventPeristenceQueue(object): | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 Vector Creations Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -29,6 +30,24 @@ _DEFAULT_ROLE_ID = "" | |||
| 
 | ||||
| 
 | ||||
| class GroupServerStore(SQLBaseStore): | ||||
|     def set_group_join_policy(self, group_id, join_policy): | ||||
|         """Set the join policy of a group. | ||||
| 
 | ||||
|         join_policy can be one of: | ||||
|          * "invite" | ||||
|          * "open" | ||||
|         """ | ||||
|         return self._simple_update_one( | ||||
|             table="groups", | ||||
|             keyvalues={ | ||||
|                 "group_id": group_id, | ||||
|             }, | ||||
|             updatevalues={ | ||||
|                 "join_policy": join_policy, | ||||
|             }, | ||||
|             desc="set_group_join_policy", | ||||
|         ) | ||||
| 
 | ||||
|     def get_group(self, group_id): | ||||
|         return self._simple_select_one( | ||||
|             table="groups", | ||||
|  | @ -36,10 +55,11 @@ class GroupServerStore(SQLBaseStore): | |||
|                 "group_id": group_id, | ||||
|             }, | ||||
|             retcols=( | ||||
|                 "name", "short_description", "long_description", "avatar_url", "is_public" | ||||
|                 "name", "short_description", "long_description", | ||||
|                 "avatar_url", "is_public", "join_policy", | ||||
|             ), | ||||
|             allow_none=True, | ||||
|             desc="is_user_in_group", | ||||
|             desc="get_group", | ||||
|         ) | ||||
| 
 | ||||
|     def get_users_in_group(self, group_id, include_private=False): | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014 - 2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -25,7 +26,7 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| # Remember to update this number every time a change is made to database | ||||
| # schema files, so the users will be informed on server restarts. | ||||
| SCHEMA_VERSION = 47 | ||||
| SCHEMA_VERSION = 48 | ||||
| 
 | ||||
| dir_path = os.path.abspath(os.path.dirname(__file__)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -594,7 +594,8 @@ class RoomStore(RoomWorkerStore, SearchStore): | |||
| 
 | ||||
|         while next_token: | ||||
|             sql = """ | ||||
|                 SELECT stream_ordering, content FROM events | ||||
|                 SELECT stream_ordering, json FROM events | ||||
|                 JOIN event_json USING (event_id) | ||||
|                 WHERE room_id = ? | ||||
|                     AND stream_ordering < ? | ||||
|                     AND contains_url = ? AND outlier = ? | ||||
|  | @ -606,8 +607,8 @@ class RoomStore(RoomWorkerStore, SearchStore): | |||
|             next_token = None | ||||
|             for stream_ordering, content_json in txn: | ||||
|                 next_token = stream_ordering | ||||
|                 content = json.loads(content_json) | ||||
| 
 | ||||
|                 event_json = json.loads(content_json) | ||||
|                 content = event_json["content"] | ||||
|                 content_url = content.get("url") | ||||
|                 thumbnail_url = content.get("info", {}).get("thumbnail_url") | ||||
| 
 | ||||
|  |  | |||
|  | @ -645,8 +645,9 @@ class RoomMemberStore(RoomMemberWorkerStore): | |||
| 
 | ||||
|         def add_membership_profile_txn(txn): | ||||
|             sql = (""" | ||||
|                 SELECT stream_ordering, event_id, events.room_id, content | ||||
|                 SELECT stream_ordering, event_id, events.room_id, event_json.json | ||||
|                 FROM events | ||||
|                 INNER JOIN event_json USING (event_id) | ||||
|                 INNER JOIN room_memberships USING (event_id) | ||||
|                 WHERE ? <= stream_ordering AND stream_ordering < ? | ||||
|                 AND type = 'm.room.member' | ||||
|  | @ -667,7 +668,8 @@ class RoomMemberStore(RoomMemberWorkerStore): | |||
|                 event_id = row["event_id"] | ||||
|                 room_id = row["room_id"] | ||||
|                 try: | ||||
|                     content = json.loads(row["content"]) | ||||
|                     event_json = json.loads(row["json"]) | ||||
|                     content = event_json['content'] | ||||
|                 except Exception: | ||||
|                     continue | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,9 +12,10 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import json | ||||
| import logging | ||||
| 
 | ||||
| import simplejson as json | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,17 @@ | |||
| /* Copyright 2018 New Vector Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| INSERT into background_updates (update_name, progress_json) | ||||
|     VALUES ('user_ips_last_seen_index', '{}'); | ||||
							
								
								
									
										22
									
								
								synapse/storage/schema/delta/48/groups_joinable.sql
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								synapse/storage/schema/delta/48/groups_joinable.sql
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,22 @@ | |||
| /* Copyright 2018 New Vector Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| /*  | ||||
|  * This isn't a real ENUM because sqlite doesn't support it | ||||
|  * and we use a default of NULL for inserted rows and interpret | ||||
|  * NULL at the python store level as necessary so that existing | ||||
|  * rows are given the correct default policy. | ||||
|  */ | ||||
| ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; | ||||
|  | @ -75,8 +75,9 @@ class SearchStore(BackgroundUpdateStore): | |||
| 
 | ||||
|         def reindex_search_txn(txn): | ||||
|             sql = ( | ||||
|                 "SELECT stream_ordering, event_id, room_id, type, content, " | ||||
|                 "SELECT stream_ordering, event_id, room_id, type, json, " | ||||
|                 " origin_server_ts FROM events" | ||||
|                 " JOIN event_json USING (event_id)" | ||||
|                 " WHERE ? <= stream_ordering AND stream_ordering < ?" | ||||
|                 " AND (%s)" | ||||
|                 " ORDER BY stream_ordering DESC" | ||||
|  | @ -104,7 +105,8 @@ class SearchStore(BackgroundUpdateStore): | |||
|                     stream_ordering = row["stream_ordering"] | ||||
|                     origin_server_ts = row["origin_server_ts"] | ||||
|                     try: | ||||
|                         content = json.loads(row["content"]) | ||||
|                         event_json = json.loads(row["json"]) | ||||
|                         content = event_json["content"] | ||||
|                     except Exception: | ||||
|                         continue | ||||
| 
 | ||||
|  |  | |||
|  | @ -667,7 +667,7 @@ class UserDirectoryStore(SQLBaseStore): | |||
|             # The array of numbers are the weights for the various part of the | ||||
|             # search: (domain, _, display name, localpart) | ||||
|             sql = """ | ||||
|                 SELECT d.user_id, display_name, avatar_url | ||||
|                 SELECT d.user_id AS user_id, display_name, avatar_url | ||||
|                 FROM user_directory_search | ||||
|                 INNER JOIN user_directory AS d USING (user_id) | ||||
|                 %s | ||||
|  | @ -702,7 +702,7 @@ class UserDirectoryStore(SQLBaseStore): | |||
|             search_query = _parse_query_sqlite(search_term) | ||||
| 
 | ||||
|             sql = """ | ||||
|                 SELECT d.user_id, display_name, avatar_url | ||||
|                 SELECT d.user_id AS user_id, display_name, avatar_url | ||||
|                 FROM user_directory_search | ||||
|                 INNER JOIN user_directory AS d USING (user_id) | ||||
|                 %s | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015, 2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -39,12 +40,11 @@ _CacheSentinel = object() | |||
| 
 | ||||
| class CacheEntry(object): | ||||
|     __slots__ = [ | ||||
|         "deferred", "sequence", "callbacks", "invalidated" | ||||
|         "deferred", "callbacks", "invalidated" | ||||
|     ] | ||||
| 
 | ||||
|     def __init__(self, deferred, sequence, callbacks): | ||||
|     def __init__(self, deferred, callbacks): | ||||
|         self.deferred = deferred | ||||
|         self.sequence = sequence | ||||
|         self.callbacks = set(callbacks) | ||||
|         self.invalidated = False | ||||
| 
 | ||||
|  | @ -62,7 +62,6 @@ class Cache(object): | |||
|         "max_entries", | ||||
|         "name", | ||||
|         "keylen", | ||||
|         "sequence", | ||||
|         "thread", | ||||
|         "metrics", | ||||
|         "_pending_deferred_cache", | ||||
|  | @ -80,7 +79,6 @@ class Cache(object): | |||
| 
 | ||||
|         self.name = name | ||||
|         self.keylen = keylen | ||||
|         self.sequence = 0 | ||||
|         self.thread = None | ||||
|         self.metrics = register_cache(name, self.cache) | ||||
| 
 | ||||
|  | @ -113,11 +111,10 @@ class Cache(object): | |||
|         callbacks = [callback] if callback else [] | ||||
|         val = self._pending_deferred_cache.get(key, _CacheSentinel) | ||||
|         if val is not _CacheSentinel: | ||||
|             if val.sequence == self.sequence: | ||||
|                 val.callbacks.update(callbacks) | ||||
|                 if update_metrics: | ||||
|                     self.metrics.inc_hits() | ||||
|                 return val.deferred | ||||
|             val.callbacks.update(callbacks) | ||||
|             if update_metrics: | ||||
|                 self.metrics.inc_hits() | ||||
|             return val.deferred | ||||
| 
 | ||||
|         val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) | ||||
|         if val is not _CacheSentinel: | ||||
|  | @ -137,12 +134,9 @@ class Cache(object): | |||
|         self.check_thread() | ||||
|         entry = CacheEntry( | ||||
|             deferred=value, | ||||
|             sequence=self.sequence, | ||||
|             callbacks=callbacks, | ||||
|         ) | ||||
| 
 | ||||
|         entry.callbacks.update(callbacks) | ||||
| 
 | ||||
|         existing_entry = self._pending_deferred_cache.pop(key, None) | ||||
|         if existing_entry: | ||||
|             existing_entry.invalidate() | ||||
|  | @ -150,13 +144,25 @@ class Cache(object): | |||
|         self._pending_deferred_cache[key] = entry | ||||
| 
 | ||||
|         def shuffle(result): | ||||
|             if self.sequence == entry.sequence: | ||||
|                 existing_entry = self._pending_deferred_cache.pop(key, None) | ||||
|                 if existing_entry is entry: | ||||
|                     self.cache.set(key, result, entry.callbacks) | ||||
|                 else: | ||||
|                     entry.invalidate() | ||||
|             existing_entry = self._pending_deferred_cache.pop(key, None) | ||||
|             if existing_entry is entry: | ||||
|                 self.cache.set(key, result, entry.callbacks) | ||||
|             else: | ||||
|                 # oops, the _pending_deferred_cache has been updated since | ||||
|                 # we started our query, so we are out of date. | ||||
|                 # | ||||
|                 # Better put back whatever we took out. (We do it this way | ||||
|                 # round, rather than peeking into the _pending_deferred_cache | ||||
|                 # and then removing on a match, to make the common case faster) | ||||
|                 if existing_entry is not None: | ||||
|                     self._pending_deferred_cache[key] = existing_entry | ||||
| 
 | ||||
|                 # we're not going to put this entry into the cache, so need | ||||
|                 # to make sure that the invalidation callbacks are called. | ||||
|                 # That was probably done when _pending_deferred_cache was | ||||
|                 # updated, but it's possible that `set` was called without | ||||
|                 # `invalidate` being previously called, in which case it may | ||||
|                 # not have been. Either way, let's double-check now. | ||||
|                 entry.invalidate() | ||||
|             return result | ||||
| 
 | ||||
|  | @ -168,25 +174,29 @@ class Cache(object): | |||
| 
 | ||||
|     def invalidate(self, key): | ||||
|         self.check_thread() | ||||
|         self.cache.pop(key, None) | ||||
| 
 | ||||
|         # Increment the sequence number so that any SELECT statements that | ||||
|         # raced with the INSERT don't update the cache (SYN-369) | ||||
|         self.sequence += 1 | ||||
|         # if we have a pending lookup for this key, remove it from the | ||||
|         # _pending_deferred_cache, which will (a) stop it being returned | ||||
|         # for future queries and (b) stop it being persisted as a proper entry | ||||
|         # in self.cache. | ||||
|         entry = self._pending_deferred_cache.pop(key, None) | ||||
| 
 | ||||
|         # run the invalidation callbacks now, rather than waiting for the | ||||
|         # deferred to resolve. | ||||
|         if entry: | ||||
|             entry.invalidate() | ||||
| 
 | ||||
|         self.cache.pop(key, None) | ||||
| 
 | ||||
|     def invalidate_many(self, key): | ||||
|         self.check_thread() | ||||
|         if not isinstance(key, tuple): | ||||
|             raise TypeError( | ||||
|                 "The cache key must be a tuple not %r" % (type(key),) | ||||
|             ) | ||||
|         self.sequence += 1 | ||||
|         self.cache.del_multi(key) | ||||
| 
 | ||||
|         # if we have a pending lookup for this key, remove it from the | ||||
|         # _pending_deferred_cache, as above | ||||
|         entry_dict = self._pending_deferred_cache.pop(key, None) | ||||
|         if entry_dict is not None: | ||||
|             for entry in iterate_tree_cache_entry(entry_dict): | ||||
|  | @ -194,8 +204,10 @@ class Cache(object): | |||
| 
 | ||||
|     def invalidate_all(self): | ||||
|         self.check_thread() | ||||
|         self.sequence += 1 | ||||
|         self.cache.clear() | ||||
|         for entry in self._pending_deferred_cache.itervalues(): | ||||
|             entry.invalidate() | ||||
|         self._pending_deferred_cache.clear() | ||||
| 
 | ||||
| 
 | ||||
| class _CacheDescriptorBase(object): | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| from frozendict import frozendict | ||||
| import simplejson as json | ||||
| 
 | ||||
| 
 | ||||
| def freeze(o): | ||||
|  | @ -49,3 +50,21 @@ def unfreeze(o): | |||
|         pass | ||||
| 
 | ||||
|     return o | ||||
| 
 | ||||
| 
 | ||||
| def _handle_frozendict(obj): | ||||
|     """Helper for EventEncoder. Makes frozendicts serializable by returning | ||||
|     the underlying dict | ||||
|     """ | ||||
|     if type(obj) is frozendict: | ||||
|         # fishing the protected dict out of the object is a bit nasty, | ||||
|         # but we don't really want the overhead of copying the dict. | ||||
|         return obj._dict | ||||
|     raise TypeError('Object of type %s is not JSON serializable' % | ||||
|                     obj.__class__.__name__) | ||||
| 
 | ||||
| 
 | ||||
| # A JSONEncoder which is capable of encoding frozendics without barfing | ||||
| frozendict_json_encoder = json.JSONEncoder( | ||||
|     default=_handle_frozendict, | ||||
| ) | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -12,6 +13,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from functools import partial | ||||
| import logging | ||||
| 
 | ||||
| import mock | ||||
|  | @ -25,6 +27,50 @@ from tests import unittest | |||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class CacheTestCase(unittest.TestCase): | ||||
|     def test_invalidate_all(self): | ||||
|         cache = descriptors.Cache("testcache") | ||||
| 
 | ||||
|         callback_record = [False, False] | ||||
| 
 | ||||
|         def record_callback(idx): | ||||
|             callback_record[idx] = True | ||||
| 
 | ||||
|         # add a couple of pending entries | ||||
|         d1 = defer.Deferred() | ||||
|         cache.set("key1", d1, partial(record_callback, 0)) | ||||
| 
 | ||||
|         d2 = defer.Deferred() | ||||
|         cache.set("key2", d2, partial(record_callback, 1)) | ||||
| 
 | ||||
|         # lookup should return the deferreds | ||||
|         self.assertIs(cache.get("key1"), d1) | ||||
|         self.assertIs(cache.get("key2"), d2) | ||||
| 
 | ||||
|         # let one of the lookups complete | ||||
|         d2.callback("result2") | ||||
|         self.assertEqual(cache.get("key2"), "result2") | ||||
| 
 | ||||
|         # now do the invalidation | ||||
|         cache.invalidate_all() | ||||
| 
 | ||||
|         # lookup should return none | ||||
|         self.assertIsNone(cache.get("key1", None)) | ||||
|         self.assertIsNone(cache.get("key2", None)) | ||||
| 
 | ||||
|         # both callbacks should have been callbacked | ||||
|         self.assertTrue( | ||||
|             callback_record[0], "Invalidation callback for key1 not called", | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             callback_record[1], "Invalidation callback for key2 not called", | ||||
|         ) | ||||
| 
 | ||||
|         # letting the other lookup complete should do nothing | ||||
|         d1.callback("result1") | ||||
|         self.assertIsNone(cache.get("key1", None)) | ||||
| 
 | ||||
| 
 | ||||
| class DescriptorTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def test_cache(self): | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue