0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Merge pull request #1676 from matrix-org/erikj/room_list

Add new API appservice specific public room list
This commit is contained in:
Erik Johnston 2016-12-12 17:00:10 +00:00 committed by GitHub
commit 1574b839e0
15 changed files with 399 additions and 42 deletions

View file

@ -89,6 +89,9 @@ class ApplicationService(object):
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
# .protocols is a publicly visible field # .protocols is a publicly visible field
if protocols: if protocols:
self.protocols = set(protocols) self.protocols = set(protocols)

View file

@ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
import logging import logging
import urllib import urllib
@ -177,6 +178,13 @@ class ApplicationServiceApi(SimpleHttpClient):
" valid result", uri) " valid result", uri)
defer.returnValue(None) defer.returnValue(None)
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
if network_id is not None:
instance["instance_id"] = ThirdPartyInstanceID(
service.id, network_id,
).to_string()
defer.returnValue(info) defer.returnValue(info)
except Exception as ex: except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", logger.warning("query_3pe_protocol to %s threw exception %s",

View file

@ -655,12 +655,15 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
def get_public_rooms(self, destination, limit=None, since_token=None, def get_public_rooms(self, destination, limit=None, since_token=None,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None):
if destination == self.server_name: if destination == self.server_name:
return return
return self.transport_layer.get_public_rooms( return self.transport_layer.get_public_rooms(
destination, limit, since_token, search_filter destination, limit, since_token, search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -249,10 +249,15 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_public_rooms(self, remote_server, limit, since_token, def get_public_rooms(self, remote_server, limit, since_token,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None):
path = PREFIX + "/publicRooms" path = PREFIX + "/publicRooms"
args = {} args = {
"include_all_networks": "true" if include_all_networks else "false",
}
if third_party_instance_id:
args["third_party_instance_id"] = third_party_instance_id,
if limit: if limit:
args["limit"] = [str(limit)] args["limit"] = [str(limit)]
if since_token: if since_token:

View file

@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
parse_boolean_from_args,
) )
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
from synapse.types import ThirdPartyInstanceID
import functools import functools
import logging import logging
@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
def on_GET(self, origin, content, query): def on_GET(self, origin, content, query):
limit = parse_integer_from_args(query, "limit", 0) limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None) since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
query, "include_all_networks", False
)
third_party_instance_id = parse_string_from_args(
query, "third_party_instance_id", None
)
if include_all_networks:
network_tuple = None
elif third_party_instance_id:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
else:
network_tuple = ThirdPartyInstanceID(None, None)
data = yield self.room_list_handler.get_local_public_room_list( data = yield self.room_list_handler.get_local_public_room_list(
limit, since_token limit, since_token,
network_tuple=network_tuple
) )
defer.returnValue((200, data)) defer.returnValue((200, data))

View file

@ -339,3 +339,22 @@ class DirectoryHandler(BaseHandler):
yield self.auth.check_can_change_room_list(room_id, requester.user) yield self.auth.check_can_change_room_list(room_id, requester.user)
yield self.store.set_room_is_public(room_id, visibility == "public") yield self.store.set_room_is_public(room_id, visibility == "public")
@defer.inlineCallbacks
def edit_published_appservice_room_list(self, appservice_id, network_id,
room_id, visibility):
"""Add or remove a room from the appservice/network specific public
room list.
Args:
appservice_id (str): ID of the appservice that owns the list
network_id (str): The ID of the network the list is associated with
room_id (str)
visibility (str): either "public" or "private"
"""
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
yield self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)

View file

@ -22,6 +22,7 @@ from synapse.api.constants import (
) )
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
from collections import namedtuple from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64 from unpaddedbase64 import encode_base64, decode_base64
@ -34,6 +35,10 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomListHandler, self).__init__(hs) super(RoomListHandler, self).__init__(hs)
@ -41,10 +46,28 @@ class RoomListHandler(BaseHandler):
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None, def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None): search_filter=None,
if search_filter: network_tuple=EMTPY_THIRD_PARTY_ID,):
# We explicitly don't bother caching searches. """Generate a local public room list.
return self._get_public_room_list(limit, since_token, search_filter)
There are multiple different lists: the main one plus one per third
party network. A client can ask for a specific list or to return all.
Args:
limit (int)
since_token (str)
search_filter (dict)
network_tuple (ThirdPartyInstanceID): Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
"""
if search_filter or network_tuple is not (None, None):
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
result = self.response_cache.get((limit, since_token)) result = self.response_cache.get((limit, since_token))
if not result: if not result:
@ -56,7 +79,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None, def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None): search_filter=None,
network_tuple=EMTPY_THIRD_PARTY_ID,):
if since_token and since_token != "END": if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token) since_token = RoomListNextBatch.from_token(since_token)
else: else:
@ -73,14 +97,15 @@ class RoomListHandler(BaseHandler):
current_public_id = yield self.store.get_current_public_room_stream_id() current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes( newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id public_room_stream_id, current_public_id,
network_tuple=network_tuple,
) )
else: else:
stream_token = yield self.store.get_room_max_stream_ordering() stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id() public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id( room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id public_room_stream_id, network_tuple=network_tuple,
) )
# We want to return rooms in a particular order: the number of joined # We want to return rooms in a particular order: the number of joined
@ -311,7 +336,8 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None, def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
if search_filter: if search_filter:
# We currently don't support searching across federation, so we have # We currently don't support searching across federation, so we have
# to do it manually without pagination # to do it manually without pagination
@ -320,6 +346,8 @@ class RoomListHandler(BaseHandler):
res = yield self._get_remote_list_cached( res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token, server_name, limit=limit, since_token=since_token,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
if search_filter: if search_filter:
@ -332,22 +360,30 @@ class RoomListHandler(BaseHandler):
defer.returnValue(res) defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None, def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None): search_filter=None, include_all_networks=False,
third_party_instance_id=None,):
repl_layer = self.hs.get_replication_layer() repl_layer = self.hs.get_replication_layer()
if search_filter: if search_filter:
# We can't cache when asking for search # We can't cache when asking for search
return repl_layer.get_public_rooms( return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token, server_name, limit=limit, since_token=since_token,
search_filter=search_filter, search_filter=search_filter, include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
result = self.remote_response_cache.get((server_name, limit, since_token)) key = (
server_name, limit, since_token, include_all_networks,
third_party_instance_id,
)
result = self.remote_response_cache.get(key)
if not result: if not result:
result = self.remote_response_cache.set( result = self.remote_response_cache.set(
(server_name, limit, since_token), key,
repl_layer.get_public_rooms( repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token, server_name, limit=limit, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
) )
return result return result

View file

@ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False):
parameter is present and not one of "true" or "false". parameter is present and not one of "true" or "false".
""" """
if name in request.args: return parse_boolean_from_args(request.args, name, default, required)
def parse_boolean_from_args(args, name, default=None, required=False):
if name in args:
try: try:
return { return {
"true": True, "true": True,
"false": False, "false": False,
}[request.args[name][0]] }[args[name][0]]
except: except:
message = ( message = (
"Boolean query parameter %r must be one of" "Boolean query parameter %r must be one of"

View file

@ -475,7 +475,7 @@ class ReplicationResource(Resource):
) )
upto_token = _position_from_rows(public_rooms_rows, current_position) upto_token = _position_from_rows(public_rooms_rows, current_position)
writer.write_header_and_rows("public_rooms", public_rooms_rows, ( writer.write_header_and_rows("public_rooms", public_rooms_rows, (
"position", "room_id", "visibility" "position", "room_id", "visibility", "appservice_id", "network_id",
), position=upto_token) ), position=upto_token)
def federation(self, writer, current_token, limit, request_streams, federation_ack): def federation(self, writer, current_token, limit, request_streams, federation_ack):

View file

@ -15,6 +15,7 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
@ -30,7 +31,7 @@ class RoomStore(BaseSlavedStore):
DataStore.get_current_public_room_stream_id.__func__ DataStore.get_current_public_room_stream_id.__func__
) )
get_public_room_ids_at_stream_id = ( get_public_room_ids_at_stream_id = (
DataStore.get_public_room_ids_at_stream_id.__func__ RoomStore.__dict__["get_public_room_ids_at_stream_id"]
) )
get_public_room_ids_at_stream_id_txn = ( get_public_room_ids_at_stream_id_txn = (
DataStore.get_public_room_ids_at_stream_id_txn.__func__ DataStore.get_public_room_ids_at_stream_id_txn.__func__

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server) ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server) ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
@ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
)
def __init__(self, hs):
super(ClientAppserviceDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
return self._edit(request, network_id, room_id, visibility)
def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
@defer.inlineCallbacks
def _edit(self, request, network_id, room_id, visibility):
requester = yield self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
403, "Only appservices can edit the appservice published room list"
)
yield self.handlers.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility,
)
defer.returnValue((200, {}))

View file

@ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
from synapse.events.utils import serialize_event, format_event_for_client_v2 from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer parse_json_object_from_request, parse_string, parse_integer
@ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
since_token = content.get("since", None) since_token = content.get("since", None)
search_filter = content.get("filter", None) search_filter = content.get("filter", None)
include_all_networks = content.get("include_all_networks", False)
third_party_instance_id = content.get("third_party_instance_id", None)
if include_all_networks:
network_tuple = None
if third_party_instance_id is not None:
raise SynapseError(
400, "Can't use include_all_networks with an explicit network"
)
elif third_party_instance_id is None:
network_tuple = ThirdPartyInstanceID(None, None)
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(
@ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
) )
else: else:
data = yield handler.get_local_public_room_list( data = yield handler.get_local_public_room_list(
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,
search_filter=search_filter, search_filter=search_filter,
network_tuple=network_tuple,
) )
defer.returnValue((200, data)) defer.returnValue((200, data))

View file

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine from .engines import PostgresEngine, Sqlite3Engine
@ -106,7 +107,11 @@ class RoomStore(SQLBaseStore):
entries = self._simple_select_list_txn( entries = self._simple_select_list_txn(
txn, txn,
table="public_room_list_stream", table="public_room_list_stream",
keyvalues={"room_id": room_id}, keyvalues={
"room_id": room_id,
"appservice_id": None,
"network_id": None,
},
retcols=("stream_id", "visibility"), retcols=("stream_id", "visibility"),
) )
@ -124,6 +129,8 @@ class RoomStore(SQLBaseStore):
"stream_id": next_id, "stream_id": next_id,
"room_id": room_id, "room_id": room_id,
"visibility": is_public, "visibility": is_public,
"appservice_id": None,
"network_id": None,
} }
) )
@ -132,6 +139,87 @@ class RoomStore(SQLBaseStore):
"set_room_is_public", "set_room_is_public",
set_room_is_public_txn, next_id, set_room_is_public_txn, next_id,
) )
self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
is_public):
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
with them, keyed off of an appservice defined `network_id`, which
basically represents a single instance of a bridge to a third party
network.
Args:
room_id (str)
appservice_id (str)
network_id (str)
is_public (bool): Whether to publish or unpublish the room from the
list.
"""
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
self._simple_insert_txn(
txn,
table="appservice_room_list",
values={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
},
)
except self.database_engine.module.IntegrityError:
# We've already inserted, nothing to do.
return
else:
self._simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
},
)
entries = self._simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
"room_id": room_id,
"appservice_id": appservice_id,
"network_id": network_id,
},
retcols=("stream_id", "visibility"),
)
entries.sort(key=lambda r: r["stream_id"])
add_to_stream = True
if entries:
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
self._simple_insert_txn(
txn,
table="public_room_list_stream",
values={
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
"appservice_id": appservice_id,
"network_id": network_id,
}
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn, next_id,
)
self.hs.get_notifier().on_new_replication_data()
def get_public_room_ids(self): def get_public_room_ids(self):
return self._simple_select_onecol( return self._simple_select_onecol(
@ -259,38 +347,96 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
def get_public_room_ids_at_stream_id(self, stream_id): @cached(num_args=2, max_entries=100)
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
"""Get pulbic rooms for a particular list, or across all lists.
Args:
stream_id (int)
network_tuple (ThirdPartyInstanceID): The list to use (None, None)
means the main list, None means all lsits.
"""
return self.runInteraction( return self.runInteraction(
"get_public_room_ids_at_stream_id", "get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn, stream_id self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple
) )
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id): def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
network_tuple):
return { return {
rm rm
for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items() for rm, vis in self.get_published_at_stream_id_txn(
txn, stream_id, network_tuple=network_tuple
).items()
if vis if vis
} }
def get_published_at_stream_id_txn(self, txn, stream_id): def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
if network_tuple:
# We want to get from a particular list. No aggregation required.
sql = (""" sql = ("""
SELECT room_id, visibility FROM public_room_list_stream SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN ( INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream FROM public_room_list_stream
WHERE stream_id <= ? WHERE stream_id <= ? %s
GROUP BY room_id GROUP BY room_id
) grouped USING (room_id, stream_id) ) grouped USING (room_id, stream_id)
""") """)
txn.execute(sql, (stream_id,)) if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
)
else:
txn.execute(
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
return dict(txn.fetchall()) return dict(txn.fetchall())
else:
# We want to get from all lists, so we need to aggregate the results
def get_public_room_changes(self, prev_stream_id, new_stream_id): logger.info("Executing full list")
sql = ("""
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
SELECT
room_id, max(stream_id) AS stream_id, appservice_id,
network_id
FROM public_room_list_stream
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
""")
txn.execute(
sql,
(stream_id,)
)
results = {}
# A room is visible if its visible on any list.
for room_id, visibility in txn.fetchall():
results[room_id] = bool(visibility) or results.get(room_id, False)
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id,
network_tuple):
def get_public_room_changes_txn(txn): def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id) then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id) now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set( now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis rm for rm, vis in now_rooms_dict.items() if vis
@ -311,7 +457,8 @@ class RoomStore(SQLBaseStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit): def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn): def get_all_new_public_rooms(txn):
sql = (""" sql = ("""
SELECT stream_id, room_id, visibility FROM public_room_list_stream SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ? WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT ? LIMIT ?

View file

@ -0,0 +1,29 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE appservice_room_list(
appservice_id TEXT NOT NULL,
network_id TEXT NOT NULL,
room_id TEXT NOT NULL
);
-- Each appservice can have multiple published room lists associated with them,
-- keyed of a particular network_id
CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list(
appservice_id, network_id, room_id
);
ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT;
ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT;

View file

@ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
class ThirdPartyInstanceID(
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a namedtuple of strings, it is deeply immutable.
def __copy__(self):
return self
def __deepcopy__(self, memo):
return self
@classmethod
def from_string(cls, s):
bits = s.split("|", 2)
if len(bits) != 2:
raise SynapseError(400, "Invalid ID %r" % (s,))
return cls(appservice_id=bits[0], network_id=bits[1])
def to_string(self):
return "%s|%s" % (self.appservice_id, self.network_id,)
__str__ = to_string
@classmethod
def create(cls, appservice_id, network_id,):
return cls(appservice_id=appservice_id, network_id=network_id)