0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-22 09:24:00 +01:00

Add bulk group publicised lookup API

This commit is contained in:
Erik Johnston 2017-08-09 13:36:22 +01:00
parent b880ff190a
commit ef8e578677
5 changed files with 142 additions and 0 deletions

View file

@ -812,3 +812,18 @@ class TransportLayerClient(object):
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
def bulk_get_publicised_groups(self, destination, user_ids):
"""Get the groups a list of users are publicising
"""
path = PREFIX + "/get_groups_publicised"
content = {"user_ids": user_ids}
return self.client.post_json(
destination=destination,
path=path,
data=content,
ignore_backoff=True,
)

View file

@ -1050,6 +1050,22 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
defer.returnValue((200, resp))
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group
"""
PATH = (
"/get_groups_publicised$"
)
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
resp = yield self.handler.bulk_get_publicised_groups(
content["user_ids"], proxy=False,
)
defer.returnValue((200, resp))
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@ -1102,6 +1118,7 @@ GROUP_SERVER_SERVLET_CLASSES = (
GROUP_LOCAL_SERVLET_CLASSES = (
FederationGroupsLocalInviteServlet,
FederationGroupsRemoveLocalUserServlet,
FederationGroupsBulkPublicisedServlet,
)

View file

@ -313,3 +313,45 @@ class GroupsLocalHandler(object):
def get_joined_groups(self, user_id):
group_ids = yield self.store.get_joined_groups(user_id)
defer.returnValue({"groups": group_ids})
@defer.inlineCallbacks
def get_publicised_groups_for_user(self, user_id):
if self.hs.is_mine_id(user_id):
result = yield self.store.get_publicised_groups_for_user(user_id)
defer.returnValue({"groups": result})
else:
result = yield self.transport_client.get_publicised_groups_for_user(
get_domain_from_id(user_id), user_id
)
# TODO: Verify attestations
defer.returnValue(result)
@defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
locals = []
for user_id in user_ids:
if self.hs.is_mine_id(user_id):
locals.append(user_id)
else:
destinations.setdefault(
get_domain_from_id(user_id), []
).append(user_id)
if not proxy and destinations:
raise SynapseError(400, "Some user_ids are not local")
results = {}
for destination, dest_user_ids in destinations.iteritems():
r = yield self.transport_client.bulk_get_publicised_groups(
destination, dest_user_ids,
)
results.update(r)
for uid in locals:
results[uid] = yield self.store.get_publicised_groups_for_user(
uid
)
defer.returnValue({"users": results})

View file

@ -584,6 +584,59 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
defer.returnValue((200, {}))
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
PATTERNS = client_v2_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(PublicisedGroupsForUserServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
yield self.auth.get_user_by_req(request)
result = yield self.groups_handler.get_publicised_groups_for_user(
user_id
)
defer.returnValue((200, result))
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
PATTERNS = client_v2_patterns(
"/publicised_groups$"
)
def __init__(self, hs):
super(PublicisedGroupsForUsersServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks
def on_POST(self, request):
yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
result = yield self.groups_handler.bulk_get_publicised_groups(
user_ids
)
defer.returnValue((200, result))
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
@ -627,3 +680,4 @@ def register_servlets(hs, http_server):
GroupRolesServlet(hs).register(http_server)
GroupSelfUpdatePublicityServlet(hs).register(http_server)
GroupSummaryUsersRoleServlet(hs).register(http_server)
PublicisedGroupsForUserServlet(hs).register(http_server)

View file

@ -835,6 +835,20 @@ class GroupServerStore(SQLBaseStore):
desc="add_room_to_group",
)
def get_publicised_groups_for_user(self, user_id):
"""Get all groups a user is publicising
"""
return self._simple_select_onecol(
table="local_group_membership",
keyvalues={
"user_id": user_id,
"membership": "join",
"is_publicised": True,
},
retcol="group_id",
desc="get_publicised_groups_for_user",
)
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""