mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-04 21:58:54 +01:00
Merge branch 'develop' into markjh/config_cleanup
Conflicts: synapse/config/captcha.py
This commit is contained in:
commit
2d4d2bbae4
16 changed files with 176 additions and 88 deletions
|
@ -35,6 +35,7 @@ from twisted.enterprise import adbapi
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.static import File
|
from twisted.web.static import File
|
||||||
from twisted.web.server import Site
|
from twisted.web.server import Site
|
||||||
|
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
|
||||||
from synapse.http.server import JsonResource, RootRedirect
|
from synapse.http.server import JsonResource, RootRedirect
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||||
|
@ -228,7 +229,11 @@ class SynapseHomeServer(HomeServer):
|
||||||
if not config.no_tls and config.bind_port is not None:
|
if not config.no_tls and config.bind_port is not None:
|
||||||
reactor.listenSSL(
|
reactor.listenSSL(
|
||||||
config.bind_port,
|
config.bind_port,
|
||||||
Site(self.root_resource),
|
SynapseSite(
|
||||||
|
"synapse.access.https",
|
||||||
|
config,
|
||||||
|
self.root_resource,
|
||||||
|
),
|
||||||
self.tls_context_factory,
|
self.tls_context_factory,
|
||||||
interface=config.bind_host
|
interface=config.bind_host
|
||||||
)
|
)
|
||||||
|
@ -237,7 +242,11 @@ class SynapseHomeServer(HomeServer):
|
||||||
if config.unsecure_port is not None:
|
if config.unsecure_port is not None:
|
||||||
reactor.listenTCP(
|
reactor.listenTCP(
|
||||||
config.unsecure_port,
|
config.unsecure_port,
|
||||||
Site(self.root_resource),
|
SynapseSite(
|
||||||
|
"synapse.access.http",
|
||||||
|
config,
|
||||||
|
self.root_resource,
|
||||||
|
),
|
||||||
interface=config.bind_host
|
interface=config.bind_host
|
||||||
)
|
)
|
||||||
logger.info("Synapse now listening on port %d", config.unsecure_port)
|
logger.info("Synapse now listening on port %d", config.unsecure_port)
|
||||||
|
@ -245,7 +254,13 @@ class SynapseHomeServer(HomeServer):
|
||||||
metrics_resource = self.get_resource_for_metrics()
|
metrics_resource = self.get_resource_for_metrics()
|
||||||
if metrics_resource and config.metrics_port is not None:
|
if metrics_resource and config.metrics_port is not None:
|
||||||
reactor.listenTCP(
|
reactor.listenTCP(
|
||||||
config.metrics_port, Site(metrics_resource), interface="127.0.0.1",
|
config.metrics_port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.metrics",
|
||||||
|
config,
|
||||||
|
metrics_resource,
|
||||||
|
),
|
||||||
|
interface="127.0.0.1",
|
||||||
)
|
)
|
||||||
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
||||||
|
|
||||||
|
@ -462,6 +477,24 @@ class SynapseService(service.Service):
|
||||||
return self._port.stopListening()
|
return self._port.stopListening()
|
||||||
|
|
||||||
|
|
||||||
|
class SynapseSite(Site):
|
||||||
|
"""
|
||||||
|
Subclass of a twisted http Site that does access logging with python's
|
||||||
|
standard logging
|
||||||
|
"""
|
||||||
|
def __init__(self, logger_name, config, resource, *args, **kwargs):
|
||||||
|
Site.__init__(self, resource, *args, **kwargs)
|
||||||
|
if config.captcha_ip_origin_is_x_forwarded:
|
||||||
|
self._log_formatter = proxiedLogFormatter
|
||||||
|
else:
|
||||||
|
self._log_formatter = combinedLogFormatter
|
||||||
|
self.access_logger = logging.getLogger(logger_name)
|
||||||
|
|
||||||
|
def log(self, request):
|
||||||
|
line = self._log_formatter(self._logDateTime, request)
|
||||||
|
self.access_logger.info(line)
|
||||||
|
|
||||||
|
|
||||||
def run(hs):
|
def run(hs):
|
||||||
|
|
||||||
def in_thread():
|
def in_thread():
|
||||||
|
|
|
@ -21,6 +21,7 @@ class CaptchaConfig(Config):
|
||||||
self.recaptcha_private_key = config["recaptcha_private_key"]
|
self.recaptcha_private_key = config["recaptcha_private_key"]
|
||||||
self.recaptcha_public_key = config["recaptcha_public_key"]
|
self.recaptcha_public_key = config["recaptcha_public_key"]
|
||||||
self.enable_registration_captcha = config["enable_registration_captcha"]
|
self.enable_registration_captcha = config["enable_registration_captcha"]
|
||||||
|
# XXX: This is used for more than just captcha
|
||||||
self.captcha_ip_origin_is_x_forwarded = (
|
self.captcha_ip_origin_is_x_forwarded = (
|
||||||
config["captcha_ip_origin_is_x_forwarded"]
|
config["captcha_ip_origin_is_x_forwarded"]
|
||||||
)
|
)
|
||||||
|
|
|
@ -87,14 +87,29 @@ class IdentityHandler(BaseHandler):
|
||||||
logger.debug("binding threepid %r to %s", creds, mxid)
|
logger.debug("binding threepid %r to %s", creds, mxid)
|
||||||
http_client = SimpleHttpClient(self.hs)
|
http_client = SimpleHttpClient(self.hs)
|
||||||
data = None
|
data = None
|
||||||
|
|
||||||
|
if 'id_server' in creds:
|
||||||
|
id_server = creds['id_server']
|
||||||
|
elif 'idServer' in creds:
|
||||||
|
id_server = creds['idServer']
|
||||||
|
else:
|
||||||
|
raise SynapseError(400, "No id_server in creds")
|
||||||
|
|
||||||
|
if 'client_secret' in creds:
|
||||||
|
client_secret = creds['client_secret']
|
||||||
|
elif 'clientSecret' in creds:
|
||||||
|
client_secret = creds['clientSecret']
|
||||||
|
else:
|
||||||
|
raise SynapseError(400, "No client_secret in creds")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = yield http_client.post_urlencoded_get_json(
|
data = yield http_client.post_urlencoded_get_json(
|
||||||
"https://%s%s" % (
|
"https://%s%s" % (
|
||||||
creds['id_server'], "/_matrix/identity/api/v1/3pid/bind"
|
id_server, "/_matrix/identity/api/v1/3pid/bind"
|
||||||
),
|
),
|
||||||
{
|
{
|
||||||
'sid': creds['sid'],
|
'sid': creds['sid'],
|
||||||
'client_secret': creds['client_secret'],
|
'client_secret': client_secret,
|
||||||
'mxid': mxid,
|
'mxid': mxid,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -267,8 +267,7 @@ class MessageHandler(BaseHandler):
|
||||||
user, pagination_config.get_source_config("presence"), None
|
user, pagination_config.get_source_config("presence"), None
|
||||||
)
|
)
|
||||||
|
|
||||||
public_rooms = yield self.store.get_rooms(is_public=True)
|
public_room_ids = yield self.store.get_public_room_ids()
|
||||||
public_room_ids = [r["room_id"] for r in public_rooms]
|
|
||||||
|
|
||||||
limit = pagin_config.limit
|
limit = pagin_config.limit
|
||||||
if limit is None:
|
if limit is None:
|
||||||
|
|
|
@ -881,7 +881,7 @@ class PresenceEventSource(object):
|
||||||
# TODO(paul): limit
|
# TODO(paul): limit
|
||||||
|
|
||||||
for serial, user_ids in presence._remote_offline_serials:
|
for serial, user_ids in presence._remote_offline_serials:
|
||||||
if serial < from_key:
|
if serial <= from_key:
|
||||||
break
|
break
|
||||||
|
|
||||||
if serial > max_serial:
|
if serial > max_serial:
|
||||||
|
|
|
@ -24,7 +24,7 @@ from syutil.jsonutil import (
|
||||||
encode_canonical_json, encode_pretty_printed_json
|
encode_canonical_json, encode_pretty_printed_json
|
||||||
)
|
)
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer
|
||||||
from twisted.web import server, resource
|
from twisted.web import server, resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
from twisted.web.util import redirectTo
|
from twisted.web.util import redirectTo
|
||||||
|
@ -179,19 +179,6 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
self._PathEntry(path_pattern, callback)
|
self._PathEntry(path_pattern, callback)
|
||||||
)
|
)
|
||||||
|
|
||||||
def start_listening(self, port):
|
|
||||||
""" Registers the http server with the twisted reactor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
port (int): The port to listen on.
|
|
||||||
|
|
||||||
"""
|
|
||||||
reactor.listenTCP(
|
|
||||||
port,
|
|
||||||
server.Site(self),
|
|
||||||
interface=self.hs.config.bind_host
|
|
||||||
)
|
|
||||||
|
|
||||||
def render(self, request):
|
def render(self, request):
|
||||||
""" This gets called by twisted every time someone sends us a request.
|
""" This gets called by twisted every time someone sends us a request.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -170,7 +170,11 @@ class PusherPool:
|
||||||
def _start_pushers(self, pushers):
|
def _start_pushers(self, pushers):
|
||||||
logger.info("Starting %d pushers", len(pushers))
|
logger.info("Starting %d pushers", len(pushers))
|
||||||
for pusherdict in pushers:
|
for pusherdict in pushers:
|
||||||
p = self._create_pusher(pusherdict)
|
try:
|
||||||
|
p = self._create_pusher(pusherdict)
|
||||||
|
except PusherConfigException:
|
||||||
|
logger.exception("Couldn't start a pusher: caught PusherConfigException")
|
||||||
|
continue
|
||||||
if p:
|
if p:
|
||||||
fullid = "%s:%s:%s" % (
|
fullid = "%s:%s:%s" % (
|
||||||
pusherdict['app_id'],
|
pusherdict['app_id'],
|
||||||
|
@ -182,6 +186,8 @@ class PusherPool:
|
||||||
self.pushers[fullid] = p
|
self.pushers[fullid] = p
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
|
logger.info("Started pushers")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_pusher(self, app_id, pushkey, user_name):
|
def remove_pusher(self, app_id, pushkey, user_name):
|
||||||
fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
|
fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
|
||||||
|
|
|
@ -355,11 +355,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
# being sent)
|
# being sent)
|
||||||
last_txn_id = self._get_last_txn(txn, service.id)
|
last_txn_id = self._get_last_txn(txn, service.id)
|
||||||
|
|
||||||
result = txn.execute(
|
txn.execute(
|
||||||
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
||||||
(service.id,)
|
(service.id,)
|
||||||
)
|
)
|
||||||
highest_txn_id = result.fetchone()[0]
|
highest_txn_id = txn.fetchone()[0]
|
||||||
if highest_txn_id is None:
|
if highest_txn_id is None:
|
||||||
highest_txn_id = 0
|
highest_txn_id = 0
|
||||||
|
|
||||||
|
@ -441,15 +441,17 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
def _get_oldest_unsent_txn(self, txn, service):
|
def _get_oldest_unsent_txn(self, txn, service):
|
||||||
# Monotonically increasing txn ids, so just select the smallest
|
# Monotonically increasing txn ids, so just select the smallest
|
||||||
# one in the txns table (we delete them when they are sent)
|
# one in the txns table (we delete them when they are sent)
|
||||||
result = txn.execute(
|
txn.execute(
|
||||||
"SELECT MIN(txn_id), * FROM application_services_txns WHERE as_id=?",
|
"SELECT * FROM application_services_txns WHERE as_id=?"
|
||||||
|
" ORDER BY txn_id ASC LIMIT 1",
|
||||||
(service.id,)
|
(service.id,)
|
||||||
)
|
)
|
||||||
entry = self.cursor_to_dict(result)[0]
|
rows = self.cursor_to_dict(txn)
|
||||||
if not entry or entry["txn_id"] is None:
|
if not rows:
|
||||||
# the min(txn_id) part will force a row, so entry may not be None
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
entry = rows[0]
|
||||||
|
|
||||||
event_ids = json.loads(entry["event_ids"])
|
event_ids = json.loads(entry["event_ids"])
|
||||||
events = self._get_events_txn(txn, event_ids)
|
events = self._get_events_txn(txn, event_ids)
|
||||||
|
|
||||||
|
@ -458,11 +460,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_last_txn(self, txn, service_id):
|
def _get_last_txn(self, txn, service_id):
|
||||||
result = txn.execute(
|
txn.execute(
|
||||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||||
(service_id,)
|
(service_id,)
|
||||||
)
|
)
|
||||||
last_txn_id = result.fetchone()
|
last_txn_id = txn.fetchone()
|
||||||
if last_txn_id is None or last_txn_id[0] is None: # no row exists
|
if last_txn_id is None or last_txn_id[0] is None: # no row exists
|
||||||
return 0
|
return 0
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -137,8 +137,13 @@ class KeyStore(SQLBaseStore):
|
||||||
ts_valid_until_ms (int): The time when this json stops being valid.
|
ts_valid_until_ms (int): The time when this json stops being valid.
|
||||||
key_json (bytes): The encoded JSON.
|
key_json (bytes): The encoded JSON.
|
||||||
"""
|
"""
|
||||||
return self._simple_insert(
|
return self._simple_upsert(
|
||||||
table="server_keys_json",
|
table="server_keys_json",
|
||||||
|
keyvalues={
|
||||||
|
"server_name": server_name,
|
||||||
|
"key_id": key_id,
|
||||||
|
"from_server": from_server,
|
||||||
|
},
|
||||||
values={
|
values={
|
||||||
"server_name": server_name,
|
"server_name": server_name,
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
|
@ -147,7 +152,6 @@ class KeyStore(SQLBaseStore):
|
||||||
"ts_valid_until_ms": ts_expires_ms,
|
"ts_valid_until_ms": ts_expires_ms,
|
||||||
"key_json": buffer(key_json_bytes),
|
"key_json": buffer(key_json_bytes),
|
||||||
},
|
},
|
||||||
or_replace=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_server_keys_json(self, server_keys):
|
def get_server_keys_json(self, server_keys):
|
||||||
|
|
|
@ -21,34 +21,62 @@ from synapse.api.errors import StoreError
|
||||||
from syutil.jsonutil import encode_canonical_json
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import simplejson as json
|
||||||
|
import types
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PusherStore(SQLBaseStore):
|
class PusherStore(SQLBaseStore):
|
||||||
|
def _decode_pushers_rows(self, rows):
|
||||||
|
for r in rows:
|
||||||
|
dataJson = r['data']
|
||||||
|
r['data'] = None
|
||||||
|
try:
|
||||||
|
if isinstance(dataJson, types.BufferType):
|
||||||
|
dataJson = str(dataJson).decode("UTF8")
|
||||||
|
|
||||||
|
r['data'] = json.loads(dataJson)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn(
|
||||||
|
"Invalid JSON in data for pusher %d: %s, %s",
|
||||||
|
r['id'], dataJson, e.message,
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(r['pushkey'], types.BufferType):
|
||||||
|
r['pushkey'] = str(r['pushkey']).decode("UTF8")
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
|
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
|
||||||
sql = (
|
def r(txn):
|
||||||
"SELECT * FROM pushers "
|
sql = (
|
||||||
"WHERE app_id = ? AND pushkey = ?"
|
"SELECT * FROM pushers"
|
||||||
)
|
" WHERE app_id = ? AND pushkey = ?"
|
||||||
|
)
|
||||||
|
|
||||||
rows = yield self._execute_and_decode(
|
txn.execute(sql, (app_id, pushkey,))
|
||||||
"get_pushers_by_app_id_and_pushkey",
|
rows = self.cursor_to_dict(txn)
|
||||||
sql,
|
|
||||||
app_id, pushkey
|
return self._decode_pushers_rows(rows)
|
||||||
|
|
||||||
|
rows = yield self.runInteraction(
|
||||||
|
"get_pushers_by_app_id_and_pushkey", r
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(rows)
|
defer.returnValue(rows)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_all_pushers(self):
|
def get_all_pushers(self):
|
||||||
sql = (
|
def get_pushers(txn):
|
||||||
"SELECT * FROM pushers"
|
txn.execute("SELECT * FROM pushers")
|
||||||
)
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
rows = yield self._execute_and_decode("get_all_pushers", sql)
|
return self._decode_pushers_rows(rows)
|
||||||
|
|
||||||
|
rows = yield self.runInteraction("get_all_pushers", get_pushers)
|
||||||
defer.returnValue(rows)
|
defer.returnValue(rows)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -72,7 +100,7 @@ class PusherStore(SQLBaseStore):
|
||||||
device_display_name=device_display_name,
|
device_display_name=device_display_name,
|
||||||
ts=pushkey_ts,
|
ts=pushkey_ts,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
data=encode_canonical_json(data).decode("UTF-8"),
|
data=encode_canonical_json(data),
|
||||||
),
|
),
|
||||||
insertion_values=dict(
|
insertion_values=dict(
|
||||||
id=next_id,
|
id=next_id,
|
||||||
|
|
|
@ -181,7 +181,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||||
yield self._simple_upsert("user_threepids", {
|
yield self._simple_upsert("user_threepids", {
|
||||||
"user": user_id,
|
"user_id": user_id,
|
||||||
"medium": medium,
|
"medium": medium,
|
||||||
"address": address,
|
"address": address,
|
||||||
}, {
|
}, {
|
||||||
|
@ -193,7 +193,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
def user_get_threepids(self, user_id):
|
def user_get_threepids(self, user_id):
|
||||||
ret = yield self._simple_select_list(
|
ret = yield self._simple_select_list(
|
||||||
"user_threepids", {
|
"user_threepids", {
|
||||||
"user": user_id
|
"user_id": user_id
|
||||||
},
|
},
|
||||||
['medium', 'address', 'validated_at', 'added_at'],
|
['medium', 'address', 'validated_at', 'added_at'],
|
||||||
'user_get_threepids'
|
'user_get_threepids'
|
||||||
|
|
|
@ -75,6 +75,16 @@ class RoomStore(SQLBaseStore):
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_public_room_ids(self):
|
||||||
|
return self._simple_select_onecol(
|
||||||
|
table="rooms",
|
||||||
|
keyvalues={
|
||||||
|
"is_public": True,
|
||||||
|
},
|
||||||
|
retcol="room_id",
|
||||||
|
desc="get_public_room_ids",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_rooms(self, is_public):
|
def get_rooms(self, is_public):
|
||||||
"""Retrieve a list of all public rooms.
|
"""Retrieve a list of all public rooms.
|
||||||
|
@ -186,14 +196,13 @@ class RoomStore(SQLBaseStore):
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
|
||||||
"WHERE c.room_id = ? "
|
"WHERE c.room_id = ? "
|
||||||
) % {
|
) % {
|
||||||
"redacted": del_sql,
|
"redacted": del_sql,
|
||||||
}
|
}
|
||||||
|
|
||||||
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
|
sql += " AND ((c.type = 'm.room.name' AND c.state_key = '')"
|
||||||
sql += " OR s.type = 'm.room.aliases')"
|
sql += " OR c.type = 'm.room.aliases')"
|
||||||
args = (room_id,)
|
args = (room_id,)
|
||||||
|
|
||||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||||
|
|
|
@ -65,6 +65,7 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_rooms_for_user.invalidate(target_user_id)
|
self.get_rooms_for_user.invalidate(target_user_id)
|
||||||
|
self.get_joined_hosts_for_room.invalidate(event.room_id)
|
||||||
|
|
||||||
def get_room_member(self, user_id, room_id):
|
def get_room_member(self, user_id, room_id):
|
||||||
"""Retrieve the current state of a room member.
|
"""Retrieve the current state of a room member.
|
||||||
|
@ -162,6 +163,7 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@cached()
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
def get_joined_hosts_for_room(self, room_id):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_joined_hosts_for_room",
|
"get_joined_hosts_for_room",
|
||||||
|
|
9
synapse/storage/schema/delta/17/user_threepids.sql
Normal file
9
synapse/storage/schema/delta/17/user_threepids.sql
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
CREATE TABLE user_threepids (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
medium TEXT NOT NULL,
|
||||||
|
address TEXT NOT NULL,
|
||||||
|
validated_at BIGINT NOT NULL,
|
||||||
|
added_at BIGINT NOT NULL,
|
||||||
|
CONSTRAINT user_medium_address UNIQUE (user_id, medium, address)
|
||||||
|
);
|
||||||
|
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
|
|
@ -128,25 +128,18 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||||
del_sql = (
|
sql = (
|
||||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
"SELECT e.*, r.event_id FROM events as e"
|
||||||
"LIMIT 1"
|
" LEFT JOIN redactions as r ON r.redacts = e.event_id"
|
||||||
|
" INNER JOIN current_state_events as c ON e.event_id = c.event_id"
|
||||||
|
" WHERE c.room_id = ? "
|
||||||
)
|
)
|
||||||
|
|
||||||
sql = (
|
|
||||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
|
||||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
|
||||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
|
||||||
"WHERE c.room_id = ? "
|
|
||||||
) % {
|
|
||||||
"redacted": del_sql,
|
|
||||||
}
|
|
||||||
|
|
||||||
if event_type and state_key is not None:
|
if event_type and state_key is not None:
|
||||||
sql += " AND s.type = ? AND s.state_key = ? "
|
sql += " AND c.type = ? AND c.state_key = ? "
|
||||||
args = (room_id, event_type, state_key)
|
args = (room_id, event_type, state_key)
|
||||||
elif event_type:
|
elif event_type:
|
||||||
sql += " AND s.type = ?"
|
sql += " AND c.type = ?"
|
||||||
args = (room_id, event_type)
|
args = (room_id, event_type)
|
||||||
else:
|
else:
|
||||||
args = (room_id, )
|
args = (room_id, )
|
||||||
|
|
|
@ -30,15 +30,13 @@ class IdGenerator(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_next(self):
|
def get_next(self):
|
||||||
|
if self._next_id is None:
|
||||||
|
yield self.store.runInteraction(
|
||||||
|
"IdGenerator_%s" % (self.table,),
|
||||||
|
self.get_next_txn,
|
||||||
|
)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if not self._next_id:
|
|
||||||
res = yield self.store._execute_and_decode(
|
|
||||||
"IdGenerator_%s" % (self.table,),
|
|
||||||
"SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._next_id = (res and res[0] and res[0]["mx"]) or 1
|
|
||||||
|
|
||||||
i = self._next_id
|
i = self._next_id
|
||||||
self._next_id += 1
|
self._next_id += 1
|
||||||
defer.returnValue(i)
|
defer.returnValue(i)
|
||||||
|
@ -86,10 +84,10 @@ class StreamIdGenerator(object):
|
||||||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
if not self._current_max:
|
||||||
if not self._current_max:
|
self._get_or_compute_current_max(txn)
|
||||||
self._compute_current_max(txn)
|
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
self._current_max += 1
|
self._current_max += 1
|
||||||
next_id = self._current_max
|
next_id = self._current_max
|
||||||
|
|
||||||
|
@ -110,22 +108,24 @@ class StreamIdGenerator(object):
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
equal to it have been successfully persisted.
|
equal to it have been successfully persisted.
|
||||||
"""
|
"""
|
||||||
|
if not self._current_max:
|
||||||
|
yield store.runInteraction(
|
||||||
|
"_compute_current_max",
|
||||||
|
self._get_or_compute_current_max,
|
||||||
|
)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._unfinished_ids:
|
if self._unfinished_ids:
|
||||||
defer.returnValue(self._unfinished_ids[0] - 1)
|
defer.returnValue(self._unfinished_ids[0] - 1)
|
||||||
|
|
||||||
if not self._current_max:
|
|
||||||
yield store.runInteraction(
|
|
||||||
"_compute_current_max",
|
|
||||||
self._compute_current_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(self._current_max)
|
defer.returnValue(self._current_max)
|
||||||
|
|
||||||
def _compute_current_max(self, txn):
|
def _get_or_compute_current_max(self, txn):
|
||||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
with self._lock:
|
||||||
val, = txn.fetchone()
|
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
||||||
|
rows = txn.fetchall()
|
||||||
|
val, = rows[0]
|
||||||
|
|
||||||
self._current_max = int(val) if val else 1
|
self._current_max = int(val) if val else 1
|
||||||
|
|
||||||
return self._current_max
|
return self._current_max
|
||||||
|
|
Loading…
Reference in a new issue