0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-04 21:58:54 +01:00

Merge branch 'develop' into markjh/config_cleanup

Conflicts:
	synapse/config/captcha.py
This commit is contained in:
Mark Haines 2015-04-30 16:54:55 +01:00
commit 2d4d2bbae4
16 changed files with 176 additions and 88 deletions

View file

@ -35,6 +35,7 @@ from twisted.enterprise import adbapi
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@ -228,7 +229,11 @@ class SynapseHomeServer(HomeServer):
if not config.no_tls and config.bind_port is not None: if not config.no_tls and config.bind_port is not None:
reactor.listenSSL( reactor.listenSSL(
config.bind_port, config.bind_port,
Site(self.root_resource), SynapseSite(
"synapse.access.https",
config,
self.root_resource,
),
self.tls_context_factory, self.tls_context_factory,
interface=config.bind_host interface=config.bind_host
) )
@ -237,7 +242,11 @@ class SynapseHomeServer(HomeServer):
if config.unsecure_port is not None: if config.unsecure_port is not None:
reactor.listenTCP( reactor.listenTCP(
config.unsecure_port, config.unsecure_port,
Site(self.root_resource), SynapseSite(
"synapse.access.http",
config,
self.root_resource,
),
interface=config.bind_host interface=config.bind_host
) )
logger.info("Synapse now listening on port %d", config.unsecure_port) logger.info("Synapse now listening on port %d", config.unsecure_port)
@ -245,7 +254,13 @@ class SynapseHomeServer(HomeServer):
metrics_resource = self.get_resource_for_metrics() metrics_resource = self.get_resource_for_metrics()
if metrics_resource and config.metrics_port is not None: if metrics_resource and config.metrics_port is not None:
reactor.listenTCP( reactor.listenTCP(
config.metrics_port, Site(metrics_resource), interface="127.0.0.1", config.metrics_port,
SynapseSite(
"synapse.access.metrics",
config,
metrics_resource,
),
interface="127.0.0.1",
) )
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port) logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
@ -462,6 +477,24 @@ class SynapseService(service.Service):
return self._port.stopListening() return self._port.stopListening()
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
if config.captcha_ip_origin_is_x_forwarded:
self._log_formatter = proxiedLogFormatter
else:
self._log_formatter = combinedLogFormatter
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
line = self._log_formatter(self._logDateTime, request)
self.access_logger.info(line)
def run(hs): def run(hs):
def in_thread(): def in_thread():

View file

@ -21,6 +21,7 @@ class CaptchaConfig(Config):
self.recaptcha_private_key = config["recaptcha_private_key"] self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"] self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"] self.enable_registration_captcha = config["enable_registration_captcha"]
# XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = ( self.captcha_ip_origin_is_x_forwarded = (
config["captcha_ip_origin_is_x_forwarded"] config["captcha_ip_origin_is_x_forwarded"]
) )

View file

@ -87,14 +87,29 @@ class IdentityHandler(BaseHandler):
logger.debug("binding threepid %r to %s", creds, mxid) logger.debug("binding threepid %r to %s", creds, mxid)
http_client = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
data = None data = None
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
else:
raise SynapseError(400, "No client_secret in creds")
try: try:
data = yield http_client.post_urlencoded_get_json( data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % ( "https://%s%s" % (
creds['id_server'], "/_matrix/identity/api/v1/3pid/bind" id_server, "/_matrix/identity/api/v1/3pid/bind"
), ),
{ {
'sid': creds['sid'], 'sid': creds['sid'],
'client_secret': creds['client_secret'], 'client_secret': client_secret,
'mxid': mxid, 'mxid': mxid,
} }
) )

View file

@ -267,8 +267,7 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None user, pagination_config.get_source_config("presence"), None
) )
public_rooms = yield self.store.get_rooms(is_public=True) public_room_ids = yield self.store.get_public_room_ids()
public_room_ids = [r["room_id"] for r in public_rooms]
limit = pagin_config.limit limit = pagin_config.limit
if limit is None: if limit is None:

View file

@ -881,7 +881,7 @@ class PresenceEventSource(object):
# TODO(paul): limit # TODO(paul): limit
for serial, user_ids in presence._remote_offline_serials: for serial, user_ids in presence._remote_offline_serials:
if serial < from_key: if serial <= from_key:
break break
if serial > max_serial: if serial > max_serial:

View file

@ -24,7 +24,7 @@ from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json encode_canonical_json, encode_pretty_printed_json
) )
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.web import server, resource from twisted.web import server, resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo from twisted.web.util import redirectTo
@ -179,19 +179,6 @@ class JsonResource(HttpServer, resource.Resource):
self._PathEntry(path_pattern, callback) self._PathEntry(path_pattern, callback)
) )
def start_listening(self, port):
""" Registers the http server with the twisted reactor.
Args:
port (int): The port to listen on.
"""
reactor.listenTCP(
port,
server.Site(self),
interface=self.hs.config.bind_host
)
def render(self, request): def render(self, request):
""" This gets called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """

View file

@ -170,7 +170,11 @@ class PusherPool:
def _start_pushers(self, pushers): def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers)) logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers: for pusherdict in pushers:
p = self._create_pusher(pusherdict) try:
p = self._create_pusher(pusherdict)
except PusherConfigException:
logger.exception("Couldn't start a pusher: caught PusherConfigException")
continue
if p: if p:
fullid = "%s:%s:%s" % ( fullid = "%s:%s:%s" % (
pusherdict['app_id'], pusherdict['app_id'],
@ -182,6 +186,8 @@ class PusherPool:
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() p.start()
logger.info("Started pushers")
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_name): def remove_pusher(self, app_id, pushkey, user_name):
fullid = "%s:%s:%s" % (app_id, pushkey, user_name) fullid = "%s:%s:%s" % (app_id, pushkey, user_name)

View file

@ -355,11 +355,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
# being sent) # being sent)
last_txn_id = self._get_last_txn(txn, service.id) last_txn_id = self._get_last_txn(txn, service.id)
result = txn.execute( txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,) (service.id,)
) )
highest_txn_id = result.fetchone()[0] highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None: if highest_txn_id is None:
highest_txn_id = 0 highest_txn_id = 0
@ -441,15 +441,17 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
def _get_oldest_unsent_txn(self, txn, service): def _get_oldest_unsent_txn(self, txn, service):
# Monotonically increasing txn ids, so just select the smallest # Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent) # one in the txns table (we delete them when they are sent)
result = txn.execute( txn.execute(
"SELECT MIN(txn_id), * FROM application_services_txns WHERE as_id=?", "SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,) (service.id,)
) )
entry = self.cursor_to_dict(result)[0] rows = self.cursor_to_dict(txn)
if not entry or entry["txn_id"] is None: if not rows:
# the min(txn_id) part will force a row, so entry may not be None
return None return None
entry = rows[0]
event_ids = json.loads(entry["event_ids"]) event_ids = json.loads(entry["event_ids"])
events = self._get_events_txn(txn, event_ids) events = self._get_events_txn(txn, event_ids)
@ -458,11 +460,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
) )
def _get_last_txn(self, txn, service_id): def _get_last_txn(self, txn, service_id):
result = txn.execute( txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?", "SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,) (service_id,)
) )
last_txn_id = result.fetchone() last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists if last_txn_id is None or last_txn_id[0] is None: # no row exists
return 0 return 0
else: else:

View file

@ -137,8 +137,13 @@ class KeyStore(SQLBaseStore):
ts_valid_until_ms (int): The time when this json stops being valid. ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON. key_json (bytes): The encoded JSON.
""" """
return self._simple_insert( return self._simple_upsert(
table="server_keys_json", table="server_keys_json",
keyvalues={
"server_name": server_name,
"key_id": key_id,
"from_server": from_server,
},
values={ values={
"server_name": server_name, "server_name": server_name,
"key_id": key_id, "key_id": key_id,
@ -147,7 +152,6 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms, "ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes), "key_json": buffer(key_json_bytes),
}, },
or_replace=True,
) )
def get_server_keys_json(self, server_keys): def get_server_keys_json(self, server_keys):

View file

@ -21,34 +21,62 @@ from synapse.api.errors import StoreError
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
import logging import logging
import simplejson as json
import types
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore): class PusherStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
for r in rows:
dataJson = r['data']
r['data'] = None
try:
if isinstance(dataJson, types.BufferType):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
r['id'], dataJson, e.message,
)
pass
if isinstance(r['pushkey'], types.BufferType):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
sql = ( def r(txn):
"SELECT * FROM pushers " sql = (
"WHERE app_id = ? AND pushkey = ?" "SELECT * FROM pushers"
) " WHERE app_id = ? AND pushkey = ?"
)
rows = yield self._execute_and_decode( txn.execute(sql, (app_id, pushkey,))
"get_pushers_by_app_id_and_pushkey", rows = self.cursor_to_dict(txn)
sql,
app_id, pushkey return self._decode_pushers_rows(rows)
rows = yield self.runInteraction(
"get_pushers_by_app_id_and_pushkey", r
) )
defer.returnValue(rows) defer.returnValue(rows)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_all_pushers(self): def get_all_pushers(self):
sql = ( def get_pushers(txn):
"SELECT * FROM pushers" txn.execute("SELECT * FROM pushers")
) rows = self.cursor_to_dict(txn)
rows = yield self._execute_and_decode("get_all_pushers", sql) return self._decode_pushers_rows(rows)
rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows) defer.returnValue(rows)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -72,7 +100,7 @@ class PusherStore(SQLBaseStore):
device_display_name=device_display_name, device_display_name=device_display_name,
ts=pushkey_ts, ts=pushkey_ts,
lang=lang, lang=lang,
data=encode_canonical_json(data).decode("UTF-8"), data=encode_canonical_json(data),
), ),
insertion_values=dict( insertion_values=dict(
id=next_id, id=next_id,

View file

@ -181,7 +181,7 @@ class RegistrationStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", { yield self._simple_upsert("user_threepids", {
"user": user_id, "user_id": user_id,
"medium": medium, "medium": medium,
"address": address, "address": address,
}, { }, {
@ -193,7 +193,7 @@ class RegistrationStore(SQLBaseStore):
def user_get_threepids(self, user_id): def user_get_threepids(self, user_id):
ret = yield self._simple_select_list( ret = yield self._simple_select_list(
"user_threepids", { "user_threepids", {
"user": user_id "user_id": user_id
}, },
['medium', 'address', 'validated_at', 'added_at'], ['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids' 'user_get_threepids'

View file

@ -75,6 +75,16 @@ class RoomStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
keyvalues={
"is_public": True,
},
retcol="room_id",
desc="get_public_room_ids",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_rooms(self, is_public): def get_rooms(self, is_public):
"""Retrieve a list of all public rooms. """Retrieve a list of all public rooms.
@ -186,14 +196,13 @@ class RoomStore(SQLBaseStore):
sql = ( sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e " "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id " "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? " "WHERE c.room_id = ? "
) % { ) % {
"redacted": del_sql, "redacted": del_sql,
} }
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')" sql += " AND ((c.type = 'm.room.name' AND c.state_key = '')"
sql += " OR s.type = 'm.room.aliases')" sql += " OR c.type = 'm.room.aliases')"
args = (room_id,) args = (room_id,)
results = yield self._execute_and_decode("get_current_state", sql, *args) results = yield self._execute_and_decode("get_current_state", sql, *args)

View file

@ -65,6 +65,7 @@ class RoomMemberStore(SQLBaseStore):
) )
self.get_rooms_for_user.invalidate(target_user_id) self.get_rooms_for_user.invalidate(target_user_id)
self.get_joined_hosts_for_room.invalidate(event.room_id)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -162,6 +163,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
@cached()
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_joined_hosts_for_room", "get_joined_hosts_for_room",

View file

@ -0,0 +1,9 @@
CREATE TABLE user_threepids (
user_id TEXT NOT NULL,
medium TEXT NOT NULL,
address TEXT NOT NULL,
validated_at BIGINT NOT NULL,
added_at BIGINT NOT NULL,
CONSTRAINT user_medium_address UNIQUE (user_id, medium, address)
);
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);

View file

@ -128,25 +128,18 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
del_sql = ( sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id " "SELECT e.*, r.event_id FROM events as e"
"LIMIT 1" " LEFT JOIN redactions as r ON r.redacts = e.event_id"
" INNER JOIN current_state_events as c ON e.event_id = c.event_id"
" WHERE c.room_id = ? "
) )
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
if event_type and state_key is not None: if event_type and state_key is not None:
sql += " AND s.type = ? AND s.state_key = ? " sql += " AND c.type = ? AND c.state_key = ? "
args = (room_id, event_type, state_key) args = (room_id, event_type, state_key)
elif event_type: elif event_type:
sql += " AND s.type = ?" sql += " AND c.type = ?"
args = (room_id, event_type) args = (room_id, event_type)
else: else:
args = (room_id, ) args = (room_id, )

View file

@ -30,15 +30,13 @@ class IdGenerator(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_next(self): def get_next(self):
if self._next_id is None:
yield self.store.runInteraction(
"IdGenerator_%s" % (self.table,),
self.get_next_txn,
)
with self._lock: with self._lock:
if not self._next_id:
res = yield self.store._execute_and_decode(
"IdGenerator_%s" % (self.table,),
"SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
)
self._next_id = (res and res[0] and res[0]["mx"]) or 1
i = self._next_id i = self._next_id
self._next_id += 1 self._next_id += 1
defer.returnValue(i) defer.returnValue(i)
@ -86,10 +84,10 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id: with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ... # ... persist event ...
""" """
with self._lock: if not self._current_max:
if not self._current_max: self._get_or_compute_current_max(txn)
self._compute_current_max(txn)
with self._lock:
self._current_max += 1 self._current_max += 1
next_id = self._current_max next_id = self._current_max
@ -110,22 +108,24 @@ class StreamIdGenerator(object):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
""" """
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:
defer.returnValue(self._unfinished_ids[0] - 1) defer.returnValue(self._unfinished_ids[0] - 1)
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._compute_current_max,
)
defer.returnValue(self._current_max) defer.returnValue(self._current_max)
def _compute_current_max(self, txn): def _get_or_compute_current_max(self, txn):
txn.execute("SELECT MAX(stream_ordering) FROM events") with self._lock:
val, = txn.fetchone() txn.execute("SELECT MAX(stream_ordering) FROM events")
rows = txn.fetchall()
val, = rows[0]
self._current_max = int(val) if val else 1 self._current_max = int(val) if val else 1
return self._current_max return self._current_max