forked from MirrorHub/synapse
Add is_guest flag to users db to track whether a user is a guest user or not. Use this so we can run _filter_events_for_client when calculating event_push_actions.
This commit is contained in:
parent
eb03625626
commit
c79f221192
9 changed files with 69 additions and 31 deletions
|
@ -23,8 +23,6 @@ from synapse.push.action_generator import ActionGenerator
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
from synapse.events.utils import serialize_event
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -256,9 +254,9 @@ class BaseHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
action_generator = ActionGenerator(self.store)
|
action_generator = ActionGenerator(self.store)
|
||||||
yield action_generator.handle_push_actions_for_event(serialize_event(
|
yield action_generator.handle_push_actions_for_event(
|
||||||
event, self.clock.time_msec()
|
event, self
|
||||||
))
|
)
|
||||||
|
|
||||||
destinations = set(extra_destinations)
|
destinations = set(extra_destinations)
|
||||||
for k, s in context.current_state.items():
|
for k, s in context.current_state.items():
|
||||||
|
|
|
@ -32,7 +32,7 @@ from synapse.crypto.event_signing import (
|
||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from synapse.events.utils import prune_event, serialize_event
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
|
@ -246,8 +246,8 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
if not backfilled and not event.internal_metadata.is_outlier():
|
if not backfilled and not event.internal_metadata.is_outlier():
|
||||||
action_generator = ActionGenerator(self.store)
|
action_generator = ActionGenerator(self.store)
|
||||||
yield action_generator.handle_push_actions_for_event(serialize_event(
|
yield action_generator.handle_push_actions_for_event(
|
||||||
event, self.clock.time_msec())
|
event, self
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -84,7 +84,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
localpart=None,
|
localpart=None,
|
||||||
password=None,
|
password=None,
|
||||||
generate_token=True,
|
generate_token=True,
|
||||||
guest_access_token=None
|
guest_access_token=None,
|
||||||
|
make_guest=False
|
||||||
):
|
):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
|
@ -118,6 +119,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
was_guest=guest_access_token is not None,
|
was_guest=guest_access_token is not None,
|
||||||
|
make_guest=make_guest
|
||||||
)
|
)
|
||||||
|
|
||||||
yield registered_user(self.distributor, user)
|
yield registered_user(self.distributor, user)
|
||||||
|
|
|
@ -33,12 +33,12 @@ class ActionGenerator:
|
||||||
# tag (ie. we just need all the users).
|
# tag (ie. we just need all the users).
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_push_actions_for_event(self, event):
|
def handle_push_actions_for_event(self, event, handler):
|
||||||
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
|
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
|
||||||
event['room_id'], self.store
|
event.room_id, self.store
|
||||||
)
|
)
|
||||||
|
|
||||||
actions_by_user = bulk_evaluator.action_for_event_by_user(event)
|
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
|
||||||
|
|
||||||
yield self.store.set_push_actions_for_event_and_users(
|
yield self.store.set_push_actions_for_event_and_users(
|
||||||
event,
|
event,
|
||||||
|
|
|
@ -23,6 +23,8 @@ from synapse.types import UserID
|
||||||
import baserules
|
import baserules
|
||||||
from push_rule_evaluator import PushRuleEvaluator
|
from push_rule_evaluator import PushRuleEvaluator
|
||||||
|
|
||||||
|
from synapse.events.utils import serialize_event
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,7 +56,7 @@ def evaluator_for_room_id(room_id, store):
|
||||||
display_names[ev.state_key] = ev.content.get("displayname")
|
display_names[ev.state_key] = ev.content.get("displayname")
|
||||||
|
|
||||||
defer.returnValue(BulkPushRuleEvaluator(
|
defer.returnValue(BulkPushRuleEvaluator(
|
||||||
room_id, rules_by_user, display_names, users
|
room_id, rules_by_user, display_names, users, store
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,13 +69,15 @@ class BulkPushRuleEvaluator:
|
||||||
the same logic to run the actual rules, but could be optimised further
|
the same logic to run the actual rules, but could be optimised further
|
||||||
(see https://matrix.org/jira/browse/SYN-562)
|
(see https://matrix.org/jira/browse/SYN-562)
|
||||||
"""
|
"""
|
||||||
def __init__(self, room_id, rules_by_user, display_names, users_in_room):
|
def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
|
||||||
self.room_id = room_id
|
self.room_id = room_id
|
||||||
self.rules_by_user = rules_by_user
|
self.rules_by_user = rules_by_user
|
||||||
self.display_names = display_names
|
self.display_names = display_names
|
||||||
self.users_in_room = users_in_room
|
self.users_in_room = users_in_room
|
||||||
|
self.store = store
|
||||||
|
|
||||||
def action_for_event_by_user(self, event):
|
@defer.inlineCallbacks
|
||||||
|
def action_for_event_by_user(self, event, handler):
|
||||||
actions_by_user = {}
|
actions_by_user = {}
|
||||||
|
|
||||||
for uid, rules in self.rules_by_user.items():
|
for uid, rules in self.rules_by_user.items():
|
||||||
|
@ -81,6 +85,13 @@ class BulkPushRuleEvaluator:
|
||||||
if uid in self.display_names:
|
if uid in self.display_names:
|
||||||
display_name = self.display_names[uid]
|
display_name = self.display_names[uid]
|
||||||
|
|
||||||
|
is_guest = yield self.store.is_guest(UserID.from_string(uid))
|
||||||
|
filtered = yield handler._filter_events_for_client(
|
||||||
|
uid, [event], is_guest=is_guest
|
||||||
|
)
|
||||||
|
if len(filtered) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
for rule in rules:
|
for rule in rules:
|
||||||
if 'enabled' in rule and not rule['enabled']:
|
if 'enabled' in rule and not rule['enabled']:
|
||||||
continue
|
continue
|
||||||
|
@ -94,14 +105,20 @@ class BulkPushRuleEvaluator:
|
||||||
if len(actions) > 0:
|
if len(actions) > 0:
|
||||||
actions_by_user[uid] = actions
|
actions_by_user[uid] = actions
|
||||||
break
|
break
|
||||||
return actions_by_user
|
defer.returnValue(actions_by_user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def event_matches_rule(event, rule,
|
def event_matches_rule(event, rule,
|
||||||
display_name, room_member_count, profile_tag):
|
display_name, room_member_count, profile_tag):
|
||||||
matches = True
|
matches = True
|
||||||
|
|
||||||
|
# passing the clock all the way into here is extremely awkward and push
|
||||||
|
# rules do not care about any of the relative timestamps, so we just
|
||||||
|
# pass 0 for the current time.
|
||||||
|
client_event = serialize_event(event, 0)
|
||||||
|
|
||||||
for cond in rule['conditions']:
|
for cond in rule['conditions']:
|
||||||
matches &= PushRuleEvaluator._event_fulfills_condition(
|
matches &= PushRuleEvaluator._event_fulfills_condition(
|
||||||
event, cond, display_name, room_member_count, profile_tag
|
client_event, cond, display_name, room_member_count, profile_tag
|
||||||
)
|
)
|
||||||
return matches
|
return matches
|
||||||
|
|
|
@ -259,7 +259,10 @@ class RegisterRestServlet(RestServlet):
|
||||||
def _do_guest_registration(self):
|
def _do_guest_registration(self):
|
||||||
if not self.hs.config.allow_guest_access:
|
if not self.hs.config.allow_guest_access:
|
||||||
defer.returnValue((403, "Guest access is disabled"))
|
defer.returnValue((403, "Guest access is disabled"))
|
||||||
user_id, _ = yield self.registration_handler.register(generate_token=False)
|
user_id, _ = yield self.registration_handler.register(
|
||||||
|
generate_token=False,
|
||||||
|
make_guest=True
|
||||||
|
)
|
||||||
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
|
@ -32,8 +32,8 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
values = []
|
values = []
|
||||||
for uid, profile_tag, actions in tuples:
|
for uid, profile_tag, actions in tuples:
|
||||||
values.append({
|
values.append({
|
||||||
'room_id': event['room_id'],
|
'room_id': event.room_id,
|
||||||
'event_id': event['event_id'],
|
'event_id': event.event_id,
|
||||||
'user_id': uid,
|
'user_id': uid,
|
||||||
'profile_tag': profile_tag,
|
'profile_tag': profile_tag,
|
||||||
'actions': json.dumps(actions)
|
'actions': json.dumps(actions)
|
||||||
|
|
|
@ -18,7 +18,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(SQLBaseStore):
|
||||||
|
@ -73,7 +73,8 @@ class RegistrationStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, user_id, token, password_hash, was_guest=False):
|
def register(self, user_id, token, password_hash,
|
||||||
|
was_guest=False, make_guest=False):
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -82,15 +83,18 @@ class RegistrationStore(SQLBaseStore):
|
||||||
password_hash (str): Optional. The password hash for this user.
|
password_hash (str): Optional. The password hash for this user.
|
||||||
was_guest (bool): Optional. Whether this is a guest account being
|
was_guest (bool): Optional. Whether this is a guest account being
|
||||||
upgraded to a non-guest account.
|
upgraded to a non-guest account.
|
||||||
|
make_guest (boolean): True if the the new user should be guest,
|
||||||
|
false to add a regular user account.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if the user_id could not be registered.
|
StoreError if the user_id could not be registered.
|
||||||
"""
|
"""
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"register",
|
"register",
|
||||||
self._register, user_id, token, password_hash, was_guest
|
self._register, user_id, token, password_hash, was_guest, make_guest
|
||||||
)
|
)
|
||||||
|
self.is_guest.invalidate((user_id,))
|
||||||
|
|
||||||
def _register(self, txn, user_id, token, password_hash, was_guest):
|
def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
|
||||||
now = int(self.clock.time())
|
now = int(self.clock.time())
|
||||||
|
|
||||||
next_id = self._access_tokens_id_gen.get_next_txn(txn)
|
next_id = self._access_tokens_id_gen.get_next_txn(txn)
|
||||||
|
@ -100,12 +104,14 @@ class RegistrationStore(SQLBaseStore):
|
||||||
txn.execute("UPDATE users SET"
|
txn.execute("UPDATE users SET"
|
||||||
" password_hash = ?,"
|
" password_hash = ?,"
|
||||||
" upgrade_ts = ?"
|
" upgrade_ts = ?"
|
||||||
|
" is_guest = ?"
|
||||||
" WHERE name = ?",
|
" WHERE name = ?",
|
||||||
[password_hash, now, user_id])
|
[password_hash, now, make_guest, user_id])
|
||||||
else:
|
else:
|
||||||
txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
|
txn.execute("INSERT INTO users "
|
||||||
"VALUES (?,?,?)",
|
"(name, password_hash, creation_ts, is_guest) "
|
||||||
[user_id, password_hash, now])
|
"VALUES (?,?,?,?)",
|
||||||
|
[user_id, password_hash, now, make_guest])
|
||||||
except self.database_engine.module.IntegrityError:
|
except self.database_engine.module.IntegrityError:
|
||||||
raise StoreError(
|
raise StoreError(
|
||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
|
@ -126,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"name": user_id,
|
"name": user_id,
|
||||||
},
|
},
|
||||||
retcols=["name", "password_hash"],
|
retcols=["name", "password_hash", "is_guest"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -136,7 +142,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT name, password_hash FROM users"
|
"SELECT name, password_hash, is_guest FROM users"
|
||||||
" WHERE lower(name) = lower(?)"
|
" WHERE lower(name) = lower(?)"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
|
@ -249,9 +255,21 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks()
|
||||||
|
def is_guest(self, user):
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user.to_string()},
|
||||||
|
retcol="is_guest",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_guest",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT users.name, access_tokens.id as token_id"
|
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
|
||||||
" FROM users"
|
" FROM users"
|
||||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||||
" WHERE token = ?"
|
" WHERE token = ?"
|
||||||
|
|
Loading…
Reference in a new issue