
Merge pull request #123 from matrix-org/postgres

Postgres
Erik Johnston 2015-04-28 14:24:08 +01:00
commit 9fbcf19188
82 changed files with 2906 additions and 823 deletions

docs/postgres.rst (new file, 34 lines added)

@@ -0,0 +1,34 @@
Using Postgres
--------------

Set up client
=============

The PostgreSQL Python connector ``psycopg2`` must be installed. In the
virtual env::

    sudo apt-get install libpq-dev
    pip install psycopg2

Synapse config
==============

Add the following line to your config file::

    database_config: <db_config_file>

Where ``<db_config_file>`` is the name of a yaml file of the following
form::

    name: psycopg2
    args:
        user: <user>
        password: <pass>
        database: <db>
        host: <host>
        cp_min: 5
        cp_max: 10

All key/value pairs in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the
twisted adbapi connection pool.
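For illustration, here is a minimal sketch (not code from this commit; the
user and database names are placeholders) of how the two kinds of keys are
consumed: twisted's ``adbapi.ConnectionPool`` keeps the ``cp_``-prefixed
keyword arguments for itself and forwards the rest to
``psycopg2.connect(..)``, while code that opens a connection directly has to
strip the ``cp_`` keys first, as the schema-preparation code in this commit
does::

    from twisted.enterprise import adbapi

    db_args = {
        "user": "synapse",      # forwarded to psycopg2.connect(..)
        "database": "synapse",  # forwarded to psycopg2.connect(..)
        "cp_min": 5,            # consumed by the connection pool
        "cp_max": 10,           # consumed by the connection pool
    }

    # The pool passes the non-cp_ arguments through to the driver.
    pool = adbapi.ConnectionPool("psycopg2", **db_args)

    # A direct connection (e.g. for preparing the schema) must drop them.
    connect_args = {
        k: v for k, v in db_args.items() if not k.startswith("cp_")
    }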


@@ -0,0 +1,732 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer, reactor
from twisted.enterprise import adbapi

from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine

import argparse
import curses
import logging
import sys
import time
import traceback
import yaml


logger = logging.getLogger("port_from_sqlite_to_postgres")


BOOLEAN_COLUMNS = {
    "events": ["processed", "outlier"],
    "rooms": ["is_public"],
    "event_edges": ["is_state"],
    "presence_list": ["accepted"],
}


APPEND_ONLY_TABLES = [
    "event_content_hashes",
    "event_reference_hashes",
    "event_signatures",
    "event_edge_hashes",
    "events",
    "event_json",
    "state_events",
    "room_memberships",
    "feedback",
    "topics",
    "room_names",
    "rooms",
    "local_media_repository",
    "local_media_repository_thumbnails",
    "remote_media_cache",
    "remote_media_cache_thumbnails",
    "redactions",
    "event_edges",
    "event_auth",
    "received_transactions",
    "sent_transactions",
    "transaction_id_to_pdu",
    "users",
    "state_groups",
    "state_groups_state",
    "event_to_state_groups",
    "rejections",
]


end_error_exec_info = None


class Store(object):
    """This object is used to pull out some of the convenience API from the
    Storage layer.

    *All* database interactions should go through this object.
    """
    def __init__(self, db_pool, engine):
        self.db_pool = db_pool
        self.database_engine = engine

    _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
    _simple_insert = SQLBaseStore.__dict__["_simple_insert"]

    _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
    _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]

    _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]

    _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
    _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]

    _execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]

    def runInteraction(self, desc, func, *args, **kwargs):
        def r(conn):
            try:
                i = 0
                N = 5
                while True:
                    try:
                        txn = conn.cursor()
                        return func(
                            LoggingTransaction(txn, desc, self.database_engine),
                            *args, **kwargs
                        )
                    except self.database_engine.module.DatabaseError as e:
                        if self.database_engine.is_deadlock(e):
                            logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
                            if i < N:
                                i += 1
                                conn.rollback()
                                continue
                        raise
            except Exception as e:
                logger.debug("[TXN FAIL] {%s} %s", desc, e)
                raise

        return self.db_pool.runWithConnection(r)

    def execute(self, f, *args, **kwargs):
        return self.runInteraction(f.__name__, f, *args, **kwargs)

    def execute_sql(self, sql, *args):
        def r(txn):
            txn.execute(sql, args)
            return txn.fetchall()
        return self.runInteraction("execute_sql", r)

    def insert_many_txn(self, txn, table, headers, rows):
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            table,
            ", ".join(k for k in headers),
            ", ".join("%s" for _ in headers)
        )

        try:
            txn.executemany(sql, rows)
        except:
            logger.exception(
                "Failed to insert: %s",
                table,
            )
            raise


class Porter(object):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @defer.inlineCallbacks
    def setup_table(self, table):
        if table in APPEND_ONLY_TABLES:
            # It's safe to just carry on inserting.
            next_chunk = yield self.postgres_store._simple_select_one_onecol(
                table="port_from_sqlite3",
                keyvalues={"table_name": table},
                retcol="rowid",
                allow_none=True,
            )

            total_to_port = None
            if next_chunk is None:
                if table == "sent_transactions":
                    next_chunk, already_ported, total_to_port = (
                        yield self._setup_sent_transactions()
                    )
                else:
                    yield self.postgres_store._simple_insert(
                        table="port_from_sqlite3",
                        values={"table_name": table, "rowid": 1}
                    )

                    next_chunk = 1
                    already_ported = 0

            if total_to_port is None:
                already_ported, total_to_port = yield self._get_total_count_to_port(
                    table, next_chunk
                )
        else:
            def delete_all(txn):
                txn.execute(
                    "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
                    (table,)
                )
                txn.execute("TRUNCATE %s CASCADE" % (table,))

            yield self.postgres_store.execute(delete_all)

            yield self.postgres_store._simple_insert(
                table="port_from_sqlite3",
                values={"table_name": table, "rowid": 0}
            )

            next_chunk = 1

            already_ported, total_to_port = yield self._get_total_count_to_port(
                table, next_chunk
            )

        defer.returnValue((table, already_ported, total_to_port, next_chunk))

    @defer.inlineCallbacks
    def handle_table(self, table, postgres_size, table_size, next_chunk):
        if not table_size:
            return

        self.progress.add_table(table, postgres_size, table_size)

        select = (
            "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
            % (table,)
        )

        while True:
            def r(txn):
                txn.execute(select, (next_chunk, self.batch_size,))
                rows = txn.fetchall()
                headers = [column[0] for column in txn.description]

                return headers, rows

            headers, rows = yield self.sqlite_store.runInteraction("select", r)

            if rows:
                next_chunk = rows[-1][0] + 1

                self._convert_rows(table, headers, rows)

                def insert(txn):
                    self.postgres_store.insert_many_txn(
                        txn, table, headers[1:], rows
                    )

                    self.postgres_store._simple_update_one_txn(
                        txn,
                        table="port_from_sqlite3",
                        keyvalues={"table_name": table},
                        updatevalues={"rowid": next_chunk},
                    )

                yield self.postgres_store.execute(insert)

                postgres_size += len(rows)

                self.progress.update(table, postgres_size)
            else:
                return

    def setup_db(self, db_config, database_engine):
        db_conn = database_engine.module.connect(
            **{
                k: v for k, v in db_config.get("args", {}).items()
                if not k.startswith("cp_")
            }
        )

        database_engine.prepare_database(db_conn)

        db_conn.commit()

    @defer.inlineCallbacks
    def run(self):
        try:
            sqlite_db_pool = adbapi.ConnectionPool(
                self.sqlite_config["name"],
                **self.sqlite_config["args"]
            )

            postgres_db_pool = adbapi.ConnectionPool(
                self.postgres_config["name"],
                **self.postgres_config["args"]
            )

            sqlite_engine = create_engine("sqlite3")
            postgres_engine = create_engine("psycopg2")

            self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
            self.postgres_store = Store(postgres_db_pool, postgres_engine)

            # Step 1. Set up databases.
            self.progress.set_state("Preparing SQLite3")
            self.setup_db(sqlite_config, sqlite_engine)

            self.progress.set_state("Preparing PostgreSQL")
            self.setup_db(postgres_config, postgres_engine)

            # Step 2. Get tables.
            self.progress.set_state("Fetching tables")
            sqlite_tables = yield self.sqlite_store._simple_select_onecol(
                table="sqlite_master",
                keyvalues={
                    "type": "table",
                },
                retcol="name",
            )

            postgres_tables = yield self.postgres_store._simple_select_onecol(
                table="information_schema.tables",
                keyvalues={
                    "table_schema": "public",
                },
                retcol="distinct table_name",
            )

            tables = set(sqlite_tables) & set(postgres_tables)

            self.progress.set_state("Creating tables")

            logger.info("Found %d tables", len(tables))

            def create_port_table(txn):
                txn.execute(
                    "CREATE TABLE port_from_sqlite3 ("
                    " table_name varchar(100) NOT NULL UNIQUE,"
                    " rowid bigint NOT NULL"
                    ")"
                )

            try:
                yield self.postgres_store.runInteraction(
                    "create_port_table", create_port_table
                )
            except Exception as e:
                logger.info("Failed to create port table: %s", e)

            self.progress.set_state("Setting up")

            # Set up tables.
            setup_res = yield defer.gatherResults(
                [
                    self.setup_table(table)
                    for table in tables
                    if table not in ["schema_version", "applied_schema_deltas"]
                    and not table.startswith("sqlite_")
                ],
                consumeErrors=True,
            )

            # Process tables.
            yield defer.gatherResults(
                [
                    self.handle_table(*res)
                    for res in setup_res
                ],
                consumeErrors=True,
            )

            self.progress.done()
        except:
            global end_error_exec_info
            end_error_exec_info = sys.exc_info()
            logger.exception("")
        finally:
            reactor.stop()

    def _convert_rows(self, table, headers, rows):
        bool_col_names = BOOLEAN_COLUMNS.get(table, [])

        bool_cols = [
            i for i, h in enumerate(headers) if h in bool_col_names
        ]

        def conv(j, col):
            if j in bool_cols:
                return bool(col)
            return col

        for i, row in enumerate(rows):
            rows[i] = tuple(
                self.postgres_store.database_engine.encode_parameter(
                    conv(j, col)
                )
                for j, col in enumerate(row)
                if j > 0
            )

    @defer.inlineCallbacks
    def _setup_sent_transactions(self):
        # Only save things from the last day
        yesterday = int(time.time()*1000) - 86400000

        # And save the max transaction id from each destination
        select = (
            "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
            "SELECT max(rowid) FROM sent_transactions"
            " GROUP BY destination"
            ")"
        )

        def r(txn):
            txn.execute(select)
            rows = txn.fetchall()
            headers = [column[0] for column in txn.description]

            ts_ind = headers.index('ts')

            return headers, [r for r in rows if r[ts_ind] < yesterday]

        headers, rows = yield self.sqlite_store.runInteraction(
            "select", r,
        )

        self._convert_rows("sent_transactions", headers, rows)

        inserted_rows = len(rows)
        max_inserted_rowid = max(r[0] for r in rows)

        def insert(txn):
            self.postgres_store.insert_many_txn(
                txn, "sent_transactions", headers[1:], rows
            )

        yield self.postgres_store.execute(insert)

        def get_start_id(txn):
            txn.execute(
                "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                " ORDER BY rowid ASC LIMIT 1",
                (yesterday,)
            )

            rows = txn.fetchall()
            if rows:
                return rows[0][0]
            else:
                return 1

        next_chunk = yield self.sqlite_store.execute(get_start_id)
        next_chunk = max(max_inserted_rowid + 1, next_chunk)

        yield self.postgres_store._simple_insert(
            table="port_from_sqlite3",
            values={"table_name": "sent_transactions", "rowid": next_chunk}
        )

        def get_sent_table_size(txn):
            txn.execute(
                "SELECT count(*) FROM sent_transactions"
                " WHERE ts >= ?",
                (yesterday,)
            )
            size, = txn.fetchone()
            return int(size)

        remaining_count = yield self.sqlite_store.execute(
            get_sent_table_size
        )

        total_count = remaining_count + inserted_rows

        defer.returnValue((next_chunk, inserted_rows, total_count))

    @defer.inlineCallbacks
    def _get_remaining_count_to_port(self, table, next_chunk):
        rows = yield self.sqlite_store.execute_sql(
            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
            next_chunk,
        )

        defer.returnValue(rows[0][0])

    @defer.inlineCallbacks
    def _get_already_ported_count(self, table):
        rows = yield self.postgres_store.execute_sql(
            "SELECT count(*) FROM %s" % (table,),
        )

        defer.returnValue(rows[0][0])

    @defer.inlineCallbacks
    def _get_total_count_to_port(self, table, next_chunk):
        remaining, done = yield defer.gatherResults(
            [
                self._get_remaining_count_to_port(table, next_chunk),
                self._get_already_ported_count(table),
            ],
            consumeErrors=True,
        )

        remaining = int(remaining) if remaining else 0
        done = int(done) if done else 0

        defer.returnValue((done, remaining + done))


##############################################
###### The following is simply UI stuff ######
##############################################


class Progress(object):
    """Used to report progress of the port
    """
    def __init__(self):
        self.tables = {}

        self.start_time = int(time.time())

    def add_table(self, table, cur, size):
        self.tables[table] = {
            "start": cur,
            "num_done": cur,
            "total": size,
            "perc": int(cur * 100 / size),
        }

    def update(self, table, num_done):
        data = self.tables[table]
        data["num_done"] = num_done
        data["perc"] = int(num_done * 100 / data["total"])

    def done(self):
        pass


class CursesProgress(Progress):
    """Reports progress to a curses window
    """
    def __init__(self, stdscr):
        self.stdscr = stdscr

        curses.use_default_colors()
        curses.curs_set(0)

        curses.init_pair(1, curses.COLOR_RED, -1)
        curses.init_pair(2, curses.COLOR_GREEN, -1)

        self.last_update = 0

        self.finished = False

        self.total_processed = 0
        self.total_remaining = 0

        super(CursesProgress, self).__init__()

    def update(self, table, num_done):
        super(CursesProgress, self).update(table, num_done)

        self.total_processed = 0
        self.total_remaining = 0
        for table, data in self.tables.items():
            self.total_processed += data["num_done"] - data["start"]
            self.total_remaining += data["total"] - data["num_done"]

        self.render()

    def render(self, force=False):
        now = time.time()

        if not force and now - self.last_update < 0.2:
            # reactor.callLater(1, self.render)
            return

        self.stdscr.clear()

        rows, cols = self.stdscr.getmaxyx()

        duration = int(now) - int(self.start_time)

        minutes, seconds = divmod(duration, 60)
        duration_str = '%02dm %02ds' % (minutes, seconds,)

        if self.finished:
            status = "Time spent: %s (Done!)" % (duration_str,)
        else:
            if self.total_processed > 0:
                left = float(self.total_remaining) / self.total_processed

                est_remaining = (int(now) - self.start_time) * left
                est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
            else:
                est_remaining_str = "Unknown"
            status = (
                "Time spent: %s (est. remaining: %s)"
                % (duration_str, est_remaining_str,)
            )

        self.stdscr.addstr(
            0, 0,
            status,
            curses.A_BOLD,
        )

        max_len = max([len(t) for t in self.tables.keys()])

        left_margin = 5
        middle_space = 1

        items = self.tables.items()
        items.sort(
            key=lambda i: (i[1]["perc"], i[0]),
        )

        for i, (table, data) in enumerate(items):
            if i + 2 >= rows:
                break

            perc = data["perc"]

            color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

            self.stdscr.addstr(
                i+2, left_margin + max_len - len(table),
                table,
                curses.A_BOLD | color,
            )

            size = 20

            progress = "[%s%s]" % (
                "#" * int(perc*size/100),
                " " * (size - int(perc*size/100)),
            )

            self.stdscr.addstr(
                i+2, left_margin + max_len + middle_space,
                "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
            )

        if self.finished:
            self.stdscr.addstr(
                rows-1, 0,
                "Press any key to exit...",
            )

        self.stdscr.refresh()
        self.last_update = time.time()

    def done(self):
        self.finished = True
        self.render(True)
        self.stdscr.getch()

    def set_state(self, state):
        self.stdscr.clear()
        self.stdscr.addstr(
            0, 0,
            state + "...",
            curses.A_BOLD,
        )
        self.stdscr.refresh()


class TerminalProgress(Progress):
    """Just prints progress to the terminal
    """
    def update(self, table, num_done):
        super(TerminalProgress, self).update(table, num_done)

        data = self.tables[table]

        print "%s: %d%% (%d/%d)" % (
            table, data["perc"],
            data["num_done"], data["total"],
        )

    def set_state(self, state):
        print state + "..."


##############################################
##############################################


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", action='store_true')
    parser.add_argument("--curses", action='store_true')
    parser.add_argument("--sqlite-database")
    parser.add_argument(
        "--postgres-config", type=argparse.FileType('r'),
    )

    parser.add_argument("--batch-size", type=int, default=1000)

    args = parser.parse_args()

    logging_config = {
        "level": logging.DEBUG if args.v else logging.INFO,
        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
    }

    if args.curses:
        logging_config["filename"] = "port-synapse.log"

    logging.basicConfig(**logging_config)

    sqlite_config = {
        "name": "sqlite3",
        "args": {
            "database": args.sqlite_database,
            "cp_min": 1,
            "cp_max": 1,
            "check_same_thread": False,
        },
    }

    postgres_config = yaml.safe_load(args.postgres_config)

    def start(stdscr=None):
        if stdscr:
            progress = CursesProgress(stdscr)
        else:
            progress = TerminalProgress()

        porter = Porter(
            sqlite_config=sqlite_config,
            postgres_config=postgres_config,
            progress=progress,
            batch_size=args.batch_size,
        )

        reactor.callWhenRunning(porter.run)

        reactor.run()

    if args.curses:
        curses.wrapper(start)
    else:
        start()

    if end_error_exec_info:
        exc_type, exc_value, exc_traceback = end_error_exec_info
        traceback.print_exception(exc_type, exc_value, exc_traceback)
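For reference, given the argument parser above, an invocation of this port
script might look like the following (the script's file name is inferred
from its logger name and the paths are illustrative; the yaml file is the
database config described in docs/postgres.rst)::

    python port_from_sqlite_to_postgres.py \
        --sqlite-database homeserver.db \
        --postgres-config database_config.yaml \
        --curses --batch-size 1000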


@@ -17,9 +17,9 @@
 import sys
 sys.dont_write_bytecode = True

+from synapse.storage.engines import create_engine
 from synapse.storage import (
-    prepare_database, prepare_sqlite3_database, are_all_users_on_domain,
-    UpgradeDatabaseException,
+    are_all_users_on_domain, UpgradeDatabaseException,
 )

 from synapse.server import HomeServer
@@ -58,9 +58,9 @@ import os
 import re
 import resource
 import subprocess
-import sqlite3

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("synapse.app.homeserver")


 class SynapseHomeServer(HomeServer):
@@ -104,13 +104,11 @@ class SynapseHomeServer(HomeServer):
         return None

     def build_db_pool(self):
+        name = self.db_config["name"]
+
         return adbapi.ConnectionPool(
-            "sqlite3", self.get_db_name(),
-            check_same_thread=False,
-            cp_min=1,
-            cp_max=1,
-            cp_openfun=prepare_database,  # Prepare the database for each conn
-                                          # so that :memory: sqlite works
+            name,
+            **self.db_config.get("args", {})
         )

     def create_resource_tree(self, redirect_root_to_web_client):
@@ -242,9 +240,9 @@ class SynapseHomeServer(HomeServer):
         )
         logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)

-    def run_startup_checks(self, db_conn):
+    def run_startup_checks(self, db_conn, database_engine):
         all_users_native = are_all_users_on_domain(
-            db_conn, self.hostname
+            db_conn.cursor(), database_engine, self.hostname
         )
         if not all_users_native:
             sys.stderr.write(
@@ -368,15 +366,20 @@ def setup(config_options):
     tls_context_factory = context_factory.ServerContextFactory(config)

+    database_engine = create_engine(config.database_config["name"])
+    config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
+
     hs = SynapseHomeServer(
         config.server_name,
         domain_with_port=domain_with_port,
         upload_dir=os.path.abspath("uploads"),
         db_name=config.database_path,
+        db_config=config.database_config,
         tls_context_factory=tls_context_factory,
         config=config,
         content_addr=config.content_addr,
         version_string=version_string,
+        database_engine=database_engine,
     )

     hs.create_resource_tree(
@@ -388,10 +391,17 @@ def setup(config_options):
     logger.info("Preparing database: %s...", db_name)

     try:
-        with sqlite3.connect(db_name) as db_conn:
-            prepare_sqlite3_database(db_conn)
-            prepare_database(db_conn)
-            hs.run_startup_checks(db_conn)
+        db_conn = database_engine.module.connect(
+            **{
+                k: v for k, v in config.database_config.get("args", {}).items()
+                if not k.startswith("cp_")
+            }
+        )
+
+        database_engine.prepare_database(db_conn)
+        hs.run_startup_checks(db_conn, database_engine)
+
+        db_conn.commit()
     except UpgradeDatabaseException:
         sys.stderr.write(
             "\nFailed to upgrade database.\n"


@@ -15,6 +15,7 @@
 from ._base import Config

 import os
+import yaml


 class DatabaseConfig(Config):
@@ -26,18 +27,45 @@ class DatabaseConfig(Config):
         self.database_path = self.abspath(args.database_path)
         self.event_cache_size = self.parse_size(args.event_cache_size)

+        if args.database_config:
+            with open(args.database_config) as f:
+                self.database_config = yaml.safe_load(f)
+        else:
+            self.database_config = {
+                "name": "sqlite3",
+                "args": {
+                    "database": self.database_path,
+                },
+            }
+
+        name = self.database_config.get("name", None)
+        if name == "psycopg2":
+            pass
+        elif name == "sqlite3":
+            self.database_config.setdefault("args", {}).update({
+                "cp_min": 1,
+                "cp_max": 1,
+                "check_same_thread": False,
+            })
+        else:
+            raise RuntimeError("Unsupported database type '%s'" % (name,))
+
     @classmethod
     def add_arguments(cls, parser):
         super(DatabaseConfig, cls).add_arguments(parser)
         db_group = parser.add_argument_group("database")
         db_group.add_argument(
             "-d", "--database-path", default="homeserver.db",
-            help="The database name."
+            metavar="SQLITE_DATABASE_PATH", help="The database name."
         )
         db_group.add_argument(
             "--event-cache-size", default="100K",
             help="Number of events to cache in memory."
         )
+        db_group.add_argument(
+            "--database-config", default=None,
+            help="Location of the database configuration file."
+        )

     @classmethod
     def generate_config(cls, args, config_dir_path):


@@ -16,7 +16,6 @@
 from twisted.internet import defer

 from synapse.api.errors import LimitExceededError, SynapseError
-from synapse.util.async import run_on_reactor
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import UserID
@@ -58,8 +57,6 @@ class BaseHandler(object):
     @defer.inlineCallbacks
     def _create_new_client_event(self, builder):
-        yield run_on_reactor()
-
         latest_ret = yield self.store.get_latest_events_in_room(
             builder.room_id,
         )
@@ -101,8 +98,6 @@ class BaseHandler(object):
     @defer.inlineCallbacks
     def handle_new_client_event(self, event, context, extra_destinations=[],
                                 extra_users=[], suppress_auth=False):
-        yield run_on_reactor()
-
         # We now need to go and hit out to wherever we need to hit out to.

         if not suppress_auth:
@@ -143,7 +138,9 @@ class BaseHandler(object):
         )

         # Don't block waiting on waking up all the listeners.
-        d = self.notifier.on_new_room_event(event, extra_users=extra_users)
+        notify_d = self.notifier.on_new_room_event(
+            event, extra_users=extra_users
+        )

         def log_failure(f):
             logger.warn(
@@ -151,8 +148,10 @@ class BaseHandler(object):
                 event.event_id, f.value
             )

-        d.addErrback(log_failure)
+        notify_d.addErrback(log_failure)

-        yield federation_handler.handle_new_event(
+        fed_d = federation_handler.handle_new_event(
             event, destinations=destinations,
         )
+
+        fed_d.addErrback(log_failure)


@@ -179,7 +179,7 @@ class FederationHandler(BaseHandler):
             # it's probably a good idea to mark it as not in retry-state
             # for sending (although this is a bit of a leap)
             retry_timings = yield self.store.get_destination_retry_timings(origin)
-            if (retry_timings and retry_timings.retry_last_ts):
+            if retry_timings and retry_timings["retry_last_ts"]:
                 self.store.set_destination_retry_timings(origin, 0, 0)

         room = yield self.store.get_room(event.room_id)


@@ -53,7 +53,7 @@ class LoginHandler(BaseHandler):
             logger.warn("Attempted to login as %s but they do not exist", user)
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)

-        stored_hash = user_info[0]["password_hash"]
+        stored_hash = user_info["password_hash"]
         if bcrypt.checkpw(password, stored_hash):
             # generate an access token and store it.
             token = self.reg_handler._generate_token(user)


@@ -274,7 +274,8 @@ class MessageHandler(BaseHandler):
         if limit is None:
             limit = 10

-        for event in room_list:
+        @defer.inlineCallbacks
+        def handle_room(event):
             d = {
                 "room_id": event.room_id,
                 "membership": event.membership,
@@ -290,12 +291,19 @@ class MessageHandler(BaseHandler):
             rooms_ret.append(d)

             if event.membership != Membership.JOIN:
-                continue
+                return
             try:
-                messages, token = yield self.store.get_recent_events_for_room(
-                    event.room_id,
-                    limit=limit,
-                    end_token=now_token.room_key,
+                (messages, token), current_state = yield defer.gatherResults(
+                    [
+                        self.store.get_recent_events_for_room(
+                            event.room_id,
+                            limit=limit,
+                            end_token=now_token.room_key,
+                        ),
+                        self.state_handler.get_current_state(
+                            event.room_id
+                        ),
+                    ]
                 )

                 start_token = now_token.copy_and_replace("room_key", token[0])
@@ -311,9 +319,6 @@ class MessageHandler(BaseHandler):
                     "end": end_token.to_string(),
                 }

-                current_state = yield self.state_handler.get_current_state(
-                    event.room_id
-                )
                 d["state"] = [
                     serialize_event(c, time_now, as_client_event)
                     for c in current_state.values()
@@ -321,6 +326,11 @@ class MessageHandler(BaseHandler):
             except:
                 logger.exception("Failed to get snapshot")

+        yield defer.gatherResults(
+            [handle_room(e) for e in room_list],
+            consumeErrors=True
+        )
+
         ret = {
             "rooms": rooms_ret,
             "presence": presence,


@@ -311,25 +311,6 @@ class RoomMemberHandler(BaseHandler):
         # paginating
         defer.returnValue(chunk_data)

-    @defer.inlineCallbacks
-    def get_room_member(self, room_id, member_user_id, auth_user_id):
-        """Retrieve a room member from a room.
-
-        Args:
-            room_id : The room the member is in.
-            member_user_id : The member's user ID
-            auth_user_id : The user ID of the user making this request.
-        Returns:
-            The room member, or None if this member does not exist.
-        Raises:
-            SynapseError if something goes wrong.
-        """
-        yield self.auth.check_joined_room(room_id, auth_user_id)
-
-        member = yield self.store.get_room_member(user_id=member_user_id,
-                                                  room_id=room_id)
-        defer.returnValue(member)
-
     @defer.inlineCallbacks
     def change_membership(self, event, context, do_auth=True):
         """ Change the membership status of a user in a room.


@@ -98,7 +98,7 @@ class _NotificationListener(object):
         try:
             notifier.clock.cancel_call_later(self.timer)
         except:
-            logger.exception("Failed to cancel notifier timer")
+            logger.warn("Failed to cancel notifier timer")


 class Notifier(object):


@@ -19,10 +19,7 @@ from twisted.internet import defer
 from httppusher import HttpPusher
 from synapse.push import PusherConfigException

-from syutil.jsonutil import encode_canonical_json
-
 import logging
-import simplejson as json

 logger = logging.getLogger(__name__)

@@ -52,8 +49,6 @@ class PusherPool:
     @defer.inlineCallbacks
     def start(self):
         pushers = yield self.store.get_all_pushers()
-        for p in pushers:
-            p['data'] = json.loads(p['data'])
         self._start_pushers(pushers)

     @defer.inlineCallbacks
@@ -131,7 +126,7 @@ class PusherPool:
                 pushkey=pushkey,
                 pushkey_ts=self.hs.get_clock().time_msec(),
                 lang=lang,
-                data=encode_canonical_json(data).decode("UTF-8"),
+                data=data,
             )
             self._refresh_pusher(app_id, pushkey, user_name)

@@ -162,13 +157,13 @@ class PusherPool:
         resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
             app_id, pushkey
         )
+
         p = None
         for r in resultlist:
             if r['user_name'] == user_name:
                 p = r
+
         if p:
-            p['data'] = json.loads(p['data'])
-
             self._start_pushers([p])


@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)

 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 15
+SCHEMA_VERSION = 16

 dir_path = os.path.abspath(os.path.dirname(__file__))

@@ -104,14 +104,16 @@ class DataStore(RoomMemberStore, RoomStore,
             self.client_ip_last_seen.prefill(*key + (now,))

-        yield self._simple_insert(
+        yield self._simple_upsert(
             "user_ips",
-            {
-                "user": user.to_string(),
+            keyvalues={
+                "user_id": user.to_string(),
                 "access_token": access_token,
-                "device_id": device_id,
                 "ip": ip,
                 "user_agent": user_agent,
+            },
+            values={
+                "device_id": device_id,
                 "last_seen": now,
             },
             desc="insert_client_ip",
@@ -120,7 +122,7 @@ class DataStore(RoomMemberStore, RoomStore,
     def get_user_ip_and_agents(self, user):
         return self._simple_select_list(
             table="user_ips",
-            keyvalues={"user": user.to_string()},
+            keyvalues={"user_id": user.to_string()},
             retcols=[
                 "device_id", "access_token", "ip", "user_agent", "last_seen"
             ],
@@ -148,21 +150,23 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass


-def prepare_database(db_conn):
+def prepare_database(db_conn, database_engine):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
     """
     try:
         cur = db_conn.cursor()
-        version_info = _get_or_create_schema_state(cur)
+        version_info = _get_or_create_schema_state(cur, database_engine)

         if version_info:
             user_version, delta_files, upgraded = version_info
-            _upgrade_existing_database(cur, user_version, delta_files, upgraded)
+            _upgrade_existing_database(
+                cur, user_version, delta_files, upgraded, database_engine
+            )
         else:
-            _setup_new_database(cur)
+            _setup_new_database(cur, database_engine)

-        cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+        # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))

         cur.close()
         db_conn.commit()
@@ -171,7 +175,7 @@ def prepare_database(db_conn):
         raise


-def _setup_new_database(cur):
+def _setup_new_database(cur, database_engine):
     """Sets up the database by finding a base set of "full schemas" and then
     applying any necessary deltas.
@@ -225,31 +229,30 @@ def _setup_new_database(cur):
     directory_entries = os.listdir(sql_dir)

-    sql_script = "BEGIN TRANSACTION;\n"
     for filename in fnmatch.filter(directory_entries, "*.sql"):
         sql_loc = os.path.join(sql_dir, filename)
         logger.debug("Applying schema %s", sql_loc)
-        sql_script += read_schema(sql_loc)
-        sql_script += "\n"
-    sql_script += "COMMIT TRANSACTION;"
-    cur.executescript(sql_script)
+        executescript(cur, sql_loc)

     cur.execute(
-        "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-        " VALUES (?,?)",
-        (max_current_ver, False)
+        database_engine.convert_param_style(
+            "INSERT INTO schema_version (version, upgraded)"
+            " VALUES (?,?)"
+        ),
+        (max_current_ver, False,)
     )

     _upgrade_existing_database(
         cur,
         current_version=max_current_ver,
         applied_delta_files=[],
-        upgraded=False
+        upgraded=False,
+        database_engine=database_engine,
     )


 def _upgrade_existing_database(cur, current_version, applied_delta_files,
-                               upgraded):
+                               upgraded, database_engine):
     """Upgrades an existing database.

     Delta files can either be SQL stored in *.sql files, or python modules
@@ -305,6 +308,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
     if not upgraded:
         start_ver += 1

+    logger.debug("applied_delta_files: %s", applied_delta_files)
+
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.debug("Upgrading schema to v%d", v)

@@ -321,6 +326,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
         directory_entries.sort()
         for file_name in directory_entries:
             relative_path = os.path.join(str(v), file_name)
+            logger.debug("Found file: %s", relative_path)
             if relative_path in applied_delta_files:
                 continue

@@ -342,9 +348,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
                 module.run_upgrade(cur)
             elif ext == ".sql":
                 # A plain old .sql file, just read and execute it
-                delta_schema = read_schema(absolute_path)
                 logger.debug("Applying schema %s", relative_path)
-                cur.executescript(delta_schema)
+                executescript(cur, absolute_path)
             else:
                 # Not a valid delta file.
                 logger.warn(
@@ -356,24 +361,82 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,

             # Mark as done.
             cur.execute(
-                "INSERT INTO applied_schema_deltas (version, file)"
-                " VALUES (?,?)",
+                database_engine.convert_param_style(
+                    "INSERT INTO applied_schema_deltas (version, file)"
+                    " VALUES (?,?)",
+                ),
                 (v, relative_path)
             )

             cur.execute(
-                "INSERT OR REPLACE INTO schema_version (version, upgraded)"
-                " VALUES (?,?)",
+                database_engine.convert_param_style(
+                    "REPLACE INTO schema_version (version, upgraded)"
+                    " VALUES (?,?)",
+                ),
                 (v, True)
             )


-def _get_or_create_schema_state(txn):
+def get_statements(f):
+    statement_buffer = ""
+    in_comment = False  # If we're in a /* ... */ style comment
+
+    for line in f:
+        line = line.strip()
+
+        if in_comment:
+            # Check if this line contains an end to the comment
+            comments = line.split("*/", 1)
+            if len(comments) == 1:
+                continue
+            line = comments[1]
+            in_comment = False
+
+        # Remove inline block comments
+        line = re.sub(r"/\*.*\*/", " ", line)
+
+        # Does this line start a comment?
+        comments = line.split("/*", 1)
+        if len(comments) > 1:
+            line = comments[0]
+            in_comment = True
+
+        # Deal with line comments
+        line = line.split("--", 1)[0]
+        line = line.split("//", 1)[0]
+
+        # Find *all* semicolons. We need to treat first and last entry
+        # specially.
+        statements = line.split(";")
+
+        # We must prepend statement_buffer to the first statement
+        first_statement = "%s %s" % (
+            statement_buffer.strip(),
+            statements[0].strip()
+        )
+        statements[0] = first_statement
+
+        # Every entry, except the last, is a full statement
+        for statement in statements[:-1]:
+            yield statement.strip()
+
+        # The last entry did *not* end in a semicolon, so we store it for the
+        # next semicolon we find
+        statement_buffer = statements[-1].strip()
+
+
+def executescript(txn, schema_path):
+    with open(schema_path, 'r') as f:
+        for statement in get_statements(f):
+            txn.execute(statement)
+
+
+def _get_or_create_schema_state(txn, database_engine):
+    # Bluntly try creating the schema_version tables.
     schema_path = os.path.join(
         dir_path, "schema", "schema_version.sql",
     )
-    create_schema = read_schema(schema_path)
-    txn.executescript(create_schema)
+    executescript(txn, schema_path)

     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
@@ -382,10 +445,13 @@ def _get_or_create_schema_state(txn):

     if current_version:
         txn.execute(
-            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+            database_engine.convert_param_style(
+                "SELECT file FROM applied_schema_deltas WHERE version >= ?"
+            ),
             (current_version,)
         )
-        return current_version, txn.fetchall(), upgraded
+        applied_deltas = [d for d, in txn.fetchall()]
+        return current_version, applied_deltas, upgraded

     return None

@@ -417,17 +483,19 @@ def prepare_sqlite3_database(db_conn):
         if row and row[0]:
             db_conn.execute(
-                "INSERT OR REPLACE INTO schema_version (version, upgraded)"
+                "REPLACE INTO schema_version (version, upgraded)"
                 " VALUES (?,?)",
                 (row[0], False)
             )


-def are_all_users_on_domain(txn, domain):
-    sql = "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+def are_all_users_on_domain(txn, database_engine, domain):
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
     pat = "%:" + domain
-    cursor = txn.execute(sql, (pat,))
-    num_not_matching = cursor.fetchall()[0][0]
+    txn.execute(sql, (pat,))
+    num_not_matching = txn.fetchall()[0][0]
     if num_not_matching == 0:
         return True

     return False


@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics

+from util.id_generators import IdGenerator, StreamIdGenerator
+
 from twisted.internet import defer

 from collections import namedtuple, OrderedDict
@@ -145,11 +147,12 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name"]
+    __slots__ = ["txn", "name", "database_engine"]

-    def __init__(self, txn, name):
+    def __init__(self, txn, name, database_engine):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
+        object.__setattr__(self, "database_engine", database_engine)

     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -161,26 +164,32 @@ class LoggingTransaction(object):
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)

-        try:
-            if args and args[0]:
-                values = args[0]
+        sql = self.database_engine.convert_param_style(sql)
+
+        if args and args[0]:
+            args = list(args)
+            args[0] = [
+                self.database_engine.encode_parameter(a) for a in args[0]
+            ]
+            try:
                 sql_logger.debug(
-                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
+                    "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
                     self.name,
-                    *values
+                    *args[0]
                 )
-        except:
-            # Don't let logging failures stop SQL from working
-            pass
+            except:
+                # Don't let logging failures stop SQL from working
+                pass

         start = time.time() * 1000
         try:
             return self.txn.execute(
                 sql, *args, **kwargs
             )
-        except:
-            logger.exception("[SQL FAIL] {%s}", self.name)
+        except Exception as e:
+            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise
         finally:
             msecs = (time.time() * 1000) - start
             sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
@@ -245,6 +254,14 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)

+        self.database_engine = hs.database_engine
+
+        self._stream_id_gen = StreamIdGenerator()
+        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._pushers_id_gen = IdGenerator("pushers", "id", self)
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()

@@ -281,8 +298,11 @@ class SQLBaseStore(object):
         start_time = time.time() * 1000

-        def inner_func(txn, *args, **kwargs):
+        def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                if self.database_engine.is_connection_closed(conn):
+                    conn.reconnect()
+
                 current_context.copy_to(context)
                 start = time.time() * 1000
                 txn_id = self._TXN_ID
@@ -296,9 +316,48 @@ class SQLBaseStore(object):
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
                 transaction_logger.debug("[TXN START] {%s}", name)
                 try:
-                    return func(LoggingTransaction(txn, name), *args, **kwargs)
-                except:
-                    logger.exception("[TXN FAIL] {%s}", name)
+                    i = 0
+                    N = 5
+                    while True:
+                        try:
+                            txn = conn.cursor()
+                            return func(
+                                LoggingTransaction(txn, name, self.database_engine),
+                                *args, **kwargs
+                            )
+                        except self.database_engine.module.OperationalError as e:
+                            # This can happen if the database disappears mid
+                            # transaction.
+                            logger.warn(
+                                "[TXN OPERROR] {%s} %s %d/%d",
+                                name, e, i, N
+                            )
+                            if i < N:
+                                i += 1
+                                try:
+                                    conn.rollback()
+                                except self.database_engine.module.Error as e1:
+                                    logger.warn(
+                                        "[TXN EROLL] {%s} %s",
+                                        name, e1,
+                                    )
+                                continue
+                        except self.database_engine.module.DatabaseError as e:
+                            if self.database_engine.is_deadlock(e):
+                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                                if i < N:
+                                    i += 1
+                                    try:
+                                        conn.rollback()
+                                    except self.database_engine.module.Error as e1:
+                                        logger.warn(
+                                            "[TXN EROLL] {%s} %s",
+                                            name, e1,
+                                        )
+                                    continue
+                            raise
+                except Exception as e:
+                    logger.debug("[TXN FAIL] {%s} %s", name, e)
                     raise
                 finally:
                     end = time.time() * 1000
@@ -311,7 +370,7 @@ class SQLBaseStore(object):
             sql_txn_timer.inc_by(duration, desc)

         with PreserveLoggingContext():
-            result = yield self._db_pool.runInteraction(
+            result = yield self._db_pool.runWithConnection(
                 inner_func, *args, **kwargs
             )
         defer.returnValue(result)
@@ -342,11 +401,11 @@ class SQLBaseStore(object):
             The result of decoder(results)
         """
         def interaction(txn):
-            cursor = txn.execute(query, args)
+            txn.execute(query, args)
             if decoder:
-                return decoder(cursor)
+                return decoder(txn)
             else:
-                return cursor.fetchall()
+                return txn.fetchall()

         return self.runInteraction(desc, interaction)

@@ -356,27 +415,29 @@ class SQLBaseStore(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.

-    def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
+    @defer.inlineCallbacks
+    def _simple_insert(self, table, values, or_ignore=False,
                        desc="_simple_insert"):
         """Executes an INSERT query on the named table.

         Args:
             table : string giving the table name
             values : dict of new column names and values for them
-            or_replace : bool; if True performs an INSERT OR REPLACE
         """
-        return self.runInteraction(
-            desc,
-            self._simple_insert_txn, table, values, or_replace=or_replace,
-            or_ignore=or_ignore,
-        )
+        try:
+            yield self.runInteraction(
+                desc,
+                self._simple_insert_txn, table, values,
+            )
+        except self.database_engine.module.IntegrityError:
+            # We have to do or_ignore flag at this layer, since we can't reuse
+            # a cursor after we receive an error from the db.
+            if not or_ignore:
+                raise

     @log_function
-    def _simple_insert_txn(self, txn, table, values, or_replace=False,
-                           or_ignore=False):
-        sql = "%s INTO %s (%s) VALUES(%s)" % (
-            ("INSERT OR REPLACE" if or_replace else
-             "INSERT OR IGNORE" if or_ignore else "INSERT"),
+    def _simple_insert_txn(self, txn, table, values):
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
             table,
             ", ".join(k for k in values),
             ", ".join("?" for k in values)
@@ -388,22 +449,26 @@ class SQLBaseStore(object):
         )

         txn.execute(sql, values.values())
-        return txn.lastrowid

-    def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
+    def _simple_upsert(self, table, keyvalues, values,
+                       insertion_values={}, desc="_simple_upsert"):
         """
         Args:
             table (str): The table to upsert into
             keyvalues (dict): The unique key tables and their new values
             values (dict): The nonunique columns and their new values
+            insertion_values (dict): key/values to use when inserting
         Returns: A deferred
         """
         return self.runInteraction(
             desc,
-            self._simple_upsert_txn, table, keyvalues, values
+            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
         )

-    def _simple_upsert_txn(self, txn, table, keyvalues, values):
+    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}):
+        # We need to lock the table :(
+        self.database_engine.lock_table(txn, table)
+
         # Try to update
         sql = "UPDATE %s SET %s WHERE %s" % (
             table,
@@ -422,6 +487,7 @@ class SQLBaseStore(object):
             allvalues = {}
             allvalues.update(keyvalues)
             allvalues.update(values)
+            allvalues.update(insertion_values)

             sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                 table,
@@ -489,8 +555,7 @@ class SQLBaseStore(object):
     def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
         sql = (
-            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
-            "ORDER BY rowid asc"
+            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
         ) % {
             "retcol": retcol,
             "table": table,
@@ -548,14 +613,14 @@ class SQLBaseStore(object):
             retcols : list of strings giving the names of the columns to return
         """
         if keyvalues:
-            sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+            sql = "SELECT %s FROM %s WHERE %s" % (
                 ", ".join(retcols),
                 table,
                 " AND ".join("%s = ?" % (k, ) for k in keyvalues)
             )
             txn.execute(sql, keyvalues.values())
         else:
-            sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+            sql = "SELECT %s FROM %s" % (
                 ", ".join(retcols),
                 table
             )
@@ -607,10 +672,10 @@ class SQLBaseStore(object):
     def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
                                allow_none=False):
-        select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+        select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )

         txn.execute(select_sql, keyvalues.values())
@@ -648,6 +713,11 @@ class SQLBaseStore(object):
                 updatevalues=updatevalues,
             )

+            # if txn.rowcount == 0:
+            #     raise StoreError(404, "No row found")
+            if txn.rowcount > 1:
+                raise StoreError(500, "More than one row matched")
+
             return ret
         return self.runInteraction(desc, func)

@@ -860,6 +930,12 @@ class SQLBaseStore(object):
             result = txn.fetchone()
             return result[0] if result else None

+    def get_next_stream_id(self):
+        with self._next_stream_id_lock:
+            i = self._next_stream_id
+            self._next_stream_id += 1
+            return i
+

 class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying
@@ -883,7 +959,7 @@ class Table(object):
     _select_where_clause = "SELECT %s FROM %s WHERE %s"
     _select_clause = "SELECT %s FROM %s"
-    _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+    _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"

     @classmethod
     def select_statement(cls, where_clause=None):

@@ -366,11 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         new_txn_id = max(highest_txn_id, last_txn_id) + 1

         # Insert new txn into txn table
-        event_ids = [e.event_id for e in events]
+        event_ids = json.dumps([e.event_id for e in events])
         txn.execute(
             "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
             "VALUES(?,?,?)",
-            (service.id, new_txn_id, json.dumps(event_ids))
+            (service.id, new_txn_id, event_ids)
         )
         return AppServiceTransaction(
             service=service, id=new_txn_id, events=events


@@ -21,8 +21,6 @@ from twisted.internet import defer
 from collections import namedtuple

-import sqlite3
-
 RoomAliasMapping = namedtuple(
     "RoomAliasMapping",
@@ -91,7 +89,7 @@ class DirectoryStore(SQLBaseStore):
                 },
                 desc="create_room_alias_association",
             )
-        except sqlite3.IntegrityError:
+        except self.database_engine.module.IntegrityError:
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
             )
@@ -120,12 +118,12 @@ class DirectoryStore(SQLBaseStore):
         defer.returnValue(room_id)

     def _delete_room_alias_txn(self, txn, room_alias):
-        cursor = txn.execute(
+        txn.execute(
             "SELECT room_id FROM room_aliases WHERE room_alias = ?",
             (room_alias.to_string(),)
         )

-        res = cursor.fetchone()
+        res = txn.fetchone()
         if res:
             room_id = res[0]
         else:


@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine

import importlib


SUPPORTED_MODULE = {
    "sqlite3": Sqlite3Engine,
    "psycopg2": PostgresEngine,
}


def create_engine(name):
    engine_class = SUPPORTED_MODULE.get(name, None)

    if engine_class:
        module = importlib.import_module(name)
        return engine_class(module)

    raise RuntimeError(
        "Unsupported database engine '%s'" % (name,)
    )


@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.storage import prepare_database


class PostgresEngine(object):
    def __init__(self, database_module):
        self.module = database_module
        self.module.extensions.register_type(self.module.extensions.UNICODE)

    def convert_param_style(self, sql):
        return sql.replace("?", "%s")

    def encode_parameter(self, param):
        return param

    def on_new_connection(self, db_conn):
        db_conn.set_isolation_level(
            self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
        )

    def prepare_database(self, db_conn):
        prepare_database(db_conn, self)

    def is_deadlock(self, error):
        if isinstance(error, self.module.DatabaseError):
            return error.pgcode in ["40001", "40P01"]
        return False

    def is_connection_closed(self, conn):
        return bool(conn)

    def lock_table(self, txn, table):
        txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))


@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.storage import prepare_database, prepare_sqlite3_database


class Sqlite3Engine(object):
    def __init__(self, database_module):
        self.module = database_module

    def convert_param_style(self, sql):
        return sql

    def encode_parameter(self, param):
        return param

    def on_new_connection(self, db_conn):
        self.prepare_database(db_conn)

    def prepare_database(self, db_conn):
        prepare_sqlite3_database(db_conn)
        prepare_database(db_conn, self)

    def is_deadlock(self, error):
        return False

    def is_connection_closed(self, conn):
        return False

    def lock_table(self, txn, table):
        return
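Taken together, the two engine classes form an implicit interface that the storage layer programs against; nothing in this commit type-checks it, so the summary below is only an illustrative restatement of the methods both classes duck-type:

    # Illustrative summary of the implicit engine interface (not a class
    # that exists in this commit).
    class DatabaseEngine(object):
        def convert_param_style(self, sql):
            raise NotImplementedError   # rewrite "?" placeholders if needed

        def encode_parameter(self, param):
            raise NotImplementedError   # hook for parameter munging

        def on_new_connection(self, db_conn):
            raise NotImplementedError   # per-connection setup

        def prepare_database(self, db_conn):
            raise NotImplementedError   # create/upgrade the schema

        def is_deadlock(self, error):
            raise NotImplementedError   # should the transaction be retried?

        def is_connection_closed(self, conn):
            raise NotImplementedError

        def lock_table(self, txn, table):
            raise NotImplementedError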


@@ -153,7 +153,7 @@ class EventFederationStore(SQLBaseStore):
         results = self._get_prev_events_and_state(
             txn,
             event_id,
-            is_state=1,
+            is_state=True,
         )

         return [(e_id, h, ) for e_id, h, _ in results]
@@ -164,7 +164,7 @@ class EventFederationStore(SQLBaseStore):
         }

         if is_state is not None:
-            keyvalues["is_state"] = is_state
+            keyvalues["is_state"] = bool(is_state)

         res = self._simple_select_list_txn(
             txn,
@@ -242,7 +242,6 @@ class EventFederationStore(SQLBaseStore):
                 "room_id": room_id,
                 "min_depth": depth,
             },
-            or_replace=True,
         )

     def _handle_prev_events(self, txn, outlier, event_id, prev_events,
@@ -260,9 +259,8 @@ class EventFederationStore(SQLBaseStore):
                     "event_id": event_id,
                     "prev_event_id": e_id,
                     "room_id": room_id,
-                    "is_state": 0,
+                    "is_state": False,
                 },
-                or_ignore=True,
             )

         # Update the extremities table if this is not an outlier.
@@ -281,19 +279,19 @@ class EventFederationStore(SQLBaseStore):
             # We only insert as a forward extremity the new event if there are
             # no other events that reference it as a prev event
             query = (
-                "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
-                "SELECT ?, ? WHERE NOT EXISTS ("
-                "SELECT 1 FROM %(event_edges)s WHERE "
-                "prev_event_id = ? "
-                ")"
-            ) % {
-                "table": "event_forward_extremities",
-                "event_edges": "event_edges",
-            }
+                "SELECT 1 FROM event_edges WHERE prev_event_id = ?"
+            )

-            logger.debug("query: %s", query)
-            txn.execute(query, (event_id, room_id, event_id))
+            txn.execute(query, (event_id,))
+
+            if not txn.fetchone():
+                query = (
+                    "INSERT INTO event_forward_extremities"
+                    " (event_id, room_id)"
+                    " VALUES (?, ?)"
+                )
+                txn.execute(query, (event_id, room_id))

         # Insert all the prev_events as a backwards thing, they'll get
         # deleted in a second if they're incorrect anyway.
@@ -306,7 +304,6 @@ class EventFederationStore(SQLBaseStore):
                     "event_id": e_id,
                     "room_id": room_id,
                 },
-                or_ignore=True,
             )

         # Also delete from the backwards extremities table all ones that
@@ -400,7 +397,7 @@ class EventFederationStore(SQLBaseStore):
         query = (
             "SELECT prev_event_id FROM event_edges "
-            "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+            "WHERE room_id = ? AND event_id = ? AND is_state = ? "
             "LIMIT ?"
         )
@@ -409,7 +406,7 @@ class EventFederationStore(SQLBaseStore):
             for event_id in front:
                 txn.execute(
                     query,
-                    (room_id, event_id, limit - len(event_results))
+                    (room_id, event_id, False, limit - len(event_results))
                 )

                 for e_id, in txn.fetchall():
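The extremities hunk above replaces SQLite's INSERT OR IGNORE ... WHERE NOT EXISTS trick with a portable probe-then-insert; condensed, the pattern is (safe here only because it runs inside the surrounding transaction):

    # Condensed restatement of the probe-then-insert from the hunk above.
    def insert_forward_extremity(txn, event_id, room_id):
        txn.execute(
            "SELECT 1 FROM event_edges WHERE prev_event_id = ?",
            (event_id,),
        )
        if not txn.fetchone():
            # Nothing references this event yet, so it is a forward extremity.
            txn.execute(
                "INSERT INTO event_forward_extremities (event_id, room_id)"
                " VALUES (?, ?)",
                (event_id, room_id),
            )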


@@ -52,7 +52,6 @@ class EventsStore(SQLBaseStore):
                 is_new_state=is_new_state,
                 current_state=current_state,
             )
-            self.get_room_events_max_id.invalidate()
         except _RollbackButIsFineException:
             pass
@@ -96,12 +95,22 @@ class EventsStore(SQLBaseStore):
         # Remove the any existing cache entries for the event_id
         self._invalidate_get_event_cache(event.event_id)

+        if stream_ordering is None:
+            with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
+                return self._persist_event_txn(
+                    txn, event, context, backfilled,
+                    stream_ordering=stream_ordering,
+                    is_new_state=is_new_state,
+                    current_state=current_state,
+                )
+
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
-            txn.execute(
-                "DELETE FROM current_state_events WHERE room_id = ?",
-                (event.room_id,)
+            self._simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": event.room_id},
             )

             for s in current_state:
@@ -114,7 +123,6 @@ class EventsStore(SQLBaseStore):
                         "type": s.type,
                         "state_key": s.state_key,
                     },
-                    or_replace=True,
                 )

         if event.is_state() and is_new_state:
@@ -128,7 +136,6 @@ class EventsStore(SQLBaseStore):
                     "type": event.type,
                     "state_key": event.state_key,
                 },
-                or_replace=True,
             )

             for prev_state_id, _ in event.prev_state:
@@ -151,14 +158,6 @@ class EventsStore(SQLBaseStore):
             event.depth
         )

-        self._handle_prev_events(
-            txn,
-            outlier=outlier,
-            event_id=event.event_id,
-            prev_events=event.prev_events,
-            room_id=event.room_id,
-        )
-
         have_persisted = self._simple_select_one_onecol_txn(
             txn,
             table="event_json",
@@ -169,7 +168,7 @@ class EventsStore(SQLBaseStore):
         metadata_json = encode_canonical_json(
             event.internal_metadata.get_dict()
-        )
+        ).decode("UTF-8")

         # If we have already persisted this event, we don't need to do any
         # more processing.
@@ -185,23 +184,29 @@ class EventsStore(SQLBaseStore):
             )
             txn.execute(
                 sql,
-                (metadata_json.decode("UTF-8"), event.event_id,)
+                (metadata_json, event.event_id,)
             )

             sql = (
-                "UPDATE events SET outlier = 0"
+                "UPDATE events SET outlier = ?"
                 " WHERE event_id = ?"
             )
             txn.execute(
                 sql,
-                (event.event_id,)
+                (False, event.event_id,)
             )

             return

+        self._handle_prev_events(
+            txn,
+            outlier=outlier,
+            event_id=event.event_id,
+            prev_events=event.prev_events,
+            room_id=event.room_id,
+        )
+
         if event.type == EventTypes.Member:
             self._store_room_member_txn(txn, event)
-        elif event.type == EventTypes.Feedback:
-            self._store_feedback_txn(txn, event)
         elif event.type == EventTypes.Name:
             self._store_room_name_txn(txn, event)
         elif event.type == EventTypes.Topic:
@@ -224,10 +229,9 @@ class EventsStore(SQLBaseStore):
             values={
                 "event_id": event.event_id,
                 "room_id": event.room_id,
-                "internal_metadata": metadata_json.decode("UTF-8"),
+                "internal_metadata": metadata_json,
                 "json": encode_canonical_json(event_dict).decode("UTF-8"),
             },
-            or_replace=True,
         )

         content = encode_canonical_json(
@@ -245,9 +249,6 @@ class EventsStore(SQLBaseStore):
             "depth": event.depth,
         }

-        if stream_ordering is not None:
-            vals["stream_ordering"] = stream_ordering
-
         unrec = {
             k: v
             for k, v in event.get_dict().items()
@@ -264,67 +265,24 @@ class EventsStore(SQLBaseStore):
             unrec
         ).decode("UTF-8")

-        try:
-            self._simple_insert_txn(
-                txn,
-                "events",
-                vals,
-                or_replace=(not outlier),
-                or_ignore=bool(outlier),
-            )
-        except:
-            logger.warn(
-                "Failed to persist, probably duplicate: %s",
-                event.event_id,
-                exc_info=True,
-            )
-            raise _RollbackButIsFineException("_persist_event")
+        sql = (
+            "INSERT INTO events"
+            " (stream_ordering, topological_ordering, event_id, type,"
+            " room_id, content, processed, outlier, depth)"
+            " VALUES (?,?,?,?,?,?,?,?,?)"
+        )
+
+        txn.execute(
+            sql,
+            (
+                stream_ordering, event.depth, event.event_id, event.type,
+                event.room_id, content, True, outlier, event.depth
+            )
+        )

         if context.rejected:
             self._store_rejections_txn(txn, event.event_id, context.rejected)

-        if event.is_state():
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            self._simple_insert_txn(
-                txn,
-                "state_events",
-                vals,
-            )
-
-            if is_new_state and not context.rejected:
-                self._simple_insert_txn(
-                    txn,
-                    "current_state_events",
-                    {
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "type": event.type,
-                        "state_key": event.state_key,
-                    },
-                )
-
-            for e_id, h in event.prev_state:
-                self._simple_insert_txn(
-                    txn,
-                    table="event_edges",
-                    values={
-                        "event_id": event.event_id,
-                        "prev_event_id": e_id,
-                        "room_id": event.room_id,
-                        "is_state": 1,
-                    },
-                )
-
         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
             self._store_event_content_hash_txn(
@@ -354,6 +312,50 @@ class EventsStore(SQLBaseStore):
                 txn, event.event_id, ref_alg, ref_hash_bytes
             )

+        if event.is_state():
+            vals = {
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "type": event.type,
+                "state_key": event.state_key,
+            }
+
+            # TODO: How does this work with backfilling?
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
+
+            self._simple_insert_txn(
+                txn,
+                "state_events",
+                vals,
+            )
+
+            for e_id, h in event.prev_state:
+                self._simple_insert_txn(
+                    txn,
+                    table="event_edges",
+                    values={
+                        "event_id": event.event_id,
+                        "prev_event_id": e_id,
+                        "room_id": event.room_id,
+                        "is_state": True,
+                    },
+                )
+
+            if is_new_state and not context.rejected:
+                self._simple_upsert_txn(
+                    txn,
+                    "current_state_events",
+                    keyvalues={
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    },
+                    values={
+                        "event_id": event.event_id,
+                    }
+                )
+
     def _store_redaction(self, txn, event):
         # invalidate the cache for the redacted event
         self._invalidate_get_event_cache(event.redacts)
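The switch from or_replace=True to _simple_upsert_txn is the recurring move in this file: keyvalues identify the row, values are what gets written. A hedged sketch of what such a portable upsert helper has to do (illustrative only; the real _simple_upsert_txn lives in synapse/storage/_base.py and is not shown in this diff):

    # Illustrative-only upsert shape, since "INSERT OR REPLACE" is
    # SQLite-specific. Not the actual _simple_upsert_txn body.
    def simple_upsert_txn(txn, table, keyvalues, values):
        # First try to UPDATE the row identified by keyvalues...
        sets = ", ".join("%s = ?" % (k,) for k in values)
        where = " AND ".join("%s = ?" % (k,) for k in keyvalues)
        txn.execute(
            "UPDATE %s SET %s WHERE %s" % (table, sets, where),
            list(values.values()) + list(keyvalues.values()),
        )
        if txn.rowcount == 0:
            # ...and INSERT it if it did not exist yet.
            allvals = dict(keyvalues)
            allvals.update(values)
            txn.execute(
                "INSERT INTO %s (%s) VALUES (%s)" % (
                    table,
                    ", ".join(allvals),
                    ", ".join("?" for _ in allvals),
                ),
                list(allvals.values()),
            )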


@@ -57,16 +57,18 @@ class KeyStore(SQLBaseStore):
             OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
         )
         fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
-        return self._simple_insert(
+        return self._simple_upsert(
             table="server_tls_certificates",
-            values={
+            keyvalues={
                 "server_name": server_name,
                 "fingerprint": fingerprint,
+            },
+            values={
                 "from_server": from_server,
                 "ts_added_ms": time_now_ms,
                 "tls_certificate": buffer(tls_certificate_bytes),
             },
-            or_ignore=True,
+            desc="store_server_certificate",
         )

     @defer.inlineCallbacks
@@ -107,14 +109,16 @@ class KeyStore(SQLBaseStore):
             ts_now_ms (int): The time now in milliseconds
             verification_key (VerifyKey): The NACL verify key.
         """
-        return self._simple_insert(
+        return self._simple_upsert(
             table="server_signature_keys",
-            values={
+            keyvalues={
                 "server_name": server_name,
                 "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
+            },
+            values={
                 "from_server": from_server,
                 "ts_added_ms": time_now_ms,
                 "verify_key": buffer(verify_key.encode()),
             },
-            or_ignore=True,
+            desc="store_server_verify_key",
        )


@@ -57,6 +57,7 @@ class PresenceStore(SQLBaseStore):
             values={"observed_user_id": observed_localpart,
                     "observer_user_id": observer_userid},
             desc="allow_presence_visible",
+            or_ignore=True,
         )

     def disallow_presence_visible(self, observed_localpart, observer_userid):


@@ -154,7 +154,7 @@ class PushRuleStore(SQLBaseStore):
         txn.execute(sql, (user_name, priority_class, new_rule_priority))

         # now insert the new rule
-        sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
         sql += ",".join(new_rule.keys())+") VALUES ("
         sql += ", ".join(["?" for _ in new_rule.keys()])+")"
@@ -183,7 +183,7 @@ class PushRuleStore(SQLBaseStore):
         new_rule['priority_class'] = priority_class
         new_rule['priority'] = new_prio

-        sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+        sql = "INSERT INTO "+PushRuleTable.table_name+" ("
         sql += ",".join(new_rule.keys())+") VALUES ("
         sql += ", ".join(["?" for _ in new_rule.keys()])+")"


@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import collections
-
 from ._base import SQLBaseStore, Table
 from twisted.internet import defer

 from synapse.api.errors import StoreError
+from syutil.jsonutil import encode_canonical_json

 import logging

 logger = logging.getLogger(__name__)
@@ -28,50 +28,35 @@ logger = logging.getLogger(__name__)
 class PusherStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
-        cols = ",".join(PushersTable.fields)
         sql = (
-            "SELECT "+cols+" FROM pushers "
+            "SELECT * FROM pushers "
             "WHERE app_id = ? AND pushkey = ?"
         )

-        rows = yield self._execute(
-            "get_pushers_by_app_id_and_pushkey", None, sql,
+        rows = yield self._execute_and_decode(
+            "get_pushers_by_app_id_and_pushkey",
+            sql,
             app_id, pushkey
         )

-        ret = [
-            {
-                k: r[i] for i, k in enumerate(PushersTable.fields)
-            }
-            for r in rows
-        ]
-        print ret
-        defer.returnValue(ret)
+        defer.returnValue(rows)

     @defer.inlineCallbacks
     def get_all_pushers(self):
-        cols = ",".join(PushersTable.fields)
         sql = (
-            "SELECT "+cols+" FROM pushers"
+            "SELECT * FROM pushers"
         )

-        rows = yield self._execute("get_all_pushers", None, sql)
-        ret = [
-            {
-                k: r[i] for i, k in enumerate(PushersTable.fields)
-            }
-            for r in rows
-        ]
-
-        defer.returnValue(ret)
+        rows = yield self._execute_and_decode("get_all_pushers", sql)
+
+        defer.returnValue(rows)

     @defer.inlineCallbacks
     def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data):
         try:
+            next_id = yield self._pushers_id_gen.get_next()
             yield self._simple_upsert(
                 PushersTable.table_name,
                 dict(
@@ -87,7 +72,10 @@ class PusherStore(SQLBaseStore):
                     device_display_name=device_display_name,
                     ts=pushkey_ts,
                     lang=lang,
-                    data=data
+                    data=encode_canonical_json(data).decode("UTF-8"),
+                ),
+                insertion_values=dict(
+                    id=next_id,
                 ),
                 desc="add_pusher",
             )
@@ -135,23 +123,3 @@ class PusherStore(SQLBaseStore):

 class PushersTable(Table):
     table_name = "pushers"
-
-    fields = [
-        "id",
-        "user_name",
-        "access_token",
-        "kind",
-        "profile_tag",
-        "app_id",
-        "app_display_name",
-        "device_display_name",
-        "pushkey",
-        "ts",
-        "lang",
-        "data",
-        "last_token",
-        "last_success",
-        "failing_since"
-    ]
-
-    EntryType = collections.namedtuple("PusherEntry", fields)
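Two things change shape in this file: rows now come back from _execute_and_decode as dicts keyed by column name (so the manual zipping against PushersTable.fields goes away), and the row id is allocated up front from an id generator instead of relying on AUTOINCREMENT. A hedged sketch of the calling convention, using a hypothetical helper name:

    # Hypothetical helper, shown only for the dict-shaped rows and the
    # argument order of _execute_and_decode as used in the diff above.
    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_pushers_for_user(store, user_name):
        # `store` is a PusherStore.
        rows = yield store._execute_and_decode(
            "get_pushers_for_user",
            "SELECT * FROM pushers WHERE user_name = ?",
            user_name,
        )
        defer.returnValue(rows)  # e.g. [{"app_id": ..., "pushkey": ...}, ...]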


@@ -15,8 +15,6 @@

 from twisted.internet import defer

-from sqlite3 import IntegrityError
-
 from synapse.api.errors import StoreError, Codes

 from ._base import SQLBaseStore, cached
@@ -39,17 +37,13 @@ class RegistrationStore(SQLBaseStore):
         Raises:
             StoreError if there was a problem adding this.
         """
-        row = yield self._simple_select_one(
-            "users", {"name": user_id}, ["id"],
-            desc="add_access_token_to_user",
-        )
-        if not row:
-            raise StoreError(400, "Bad user ID supplied.")
-        row_id = row["id"]
+        next_id = yield self._access_tokens_id_gen.get_next()
+
         yield self._simple_insert(
             "access_tokens",
             {
-                "user_id": row_id,
+                "id": next_id,
+                "user_id": user_id,
                 "token": token
             },
             desc="add_access_token_to_user",
@@ -74,25 +68,33 @@ class RegistrationStore(SQLBaseStore):
     def _register(self, txn, user_id, token, password_hash):
         now = int(self.clock.time())

+        next_id = self._access_tokens_id_gen.get_next_txn(txn)
+
         try:
             txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
                         "VALUES (?,?,?)",
                         [user_id, password_hash, now])
-        except IntegrityError:
+        except self.database_engine.module.IntegrityError:
             raise StoreError(
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
             )

         # it's possible for this to get a conflict, but only for a single user
         # since tokens are namespaced based on their user ID
-        txn.execute("INSERT INTO access_tokens(user_id, token) " +
-                    "VALUES (?,?)", [txn.lastrowid, token])
+        txn.execute(
+            "INSERT INTO access_tokens(id, user_id, token)"
+            " VALUES (?,?,?)",
+            (next_id, user_id, token,)
+        )

     def get_user_by_id(self, user_id):
-        query = ("SELECT users.id, users.name, users.password_hash FROM users"
-                 " WHERE users.name = ?")
-        return self._execute(
-            "get_user_by_id", self.cursor_to_dict, query, user_id
+        return self._simple_select_one(
+            table="users",
+            keyvalues={
+                "name": user_id,
+            },
+            retcols=["name", "password_hash"],
+            allow_none=True,
         )
@@ -165,12 +167,12 @@ class RegistrationStore(SQLBaseStore):
             "SELECT users.name, users.admin,"
             " access_tokens.device_id, access_tokens.id as token_id"
             " FROM users"
-            " INNER JOIN access_tokens on users.id = access_tokens.user_id"
+            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
         )

-        cursor = txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(cursor)
+        txn.execute(sql, (token,))
+        rows = self.cursor_to_dict(txn)
         if rows:
             return rows[0]
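The other sqlite-ism removed here is txn.lastrowid: ids are now handed out by an explicit id generator, available both as a deferred (get_next()) and synchronously inside a transaction (get_next_txn(txn)). A toy sketch of the idea (the real id generator ships elsewhere in this PR and is assumed to be seeded from the existing table contents; this body is purely illustrative):

    # Toy id generator, illustrative only.
    class NaiveIdGenerator(object):
        def __init__(self, start=1):
            self._next_id = start

        def get_next_txn(self, txn):
            next_id = self._next_id
            self._next_id += 1
            return next_id

    def register_token(txn, id_gen, user_id, token):
        next_id = id_gen.get_next_txn(txn)
        txn.execute(
            "INSERT INTO access_tokens(id, user_id, token) VALUES (?,?,?)",
            (next_id, user_id, token),
        )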


@@ -72,6 +72,7 @@ class RoomStore(SQLBaseStore):
             keyvalues={"room_id": room_id},
             retcols=RoomsTable.fields,
             desc="get_room",
+            allow_none=True,
         )

     @defer.inlineCallbacks
@@ -102,24 +103,37 @@ class RoomStore(SQLBaseStore):
                 "ON c.event_id = room_names.event_id "
             )

-            # We use non printing ascii character US (\x1f) as a seperator
+            # We use non printing ascii character US (\x1F) as a separator
             sql = (
-                "SELECT r.room_id, n.name, t.topic, "
-                "group_concat(a.room_alias, '\x1f') "
-                "FROM rooms AS r "
-                "LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id "
-                "LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id "
-                "INNER JOIN room_aliases AS a ON a.room_id = r.room_id "
-                "WHERE r.is_public = ? "
-                "GROUP BY r.room_id "
+                "SELECT r.room_id, max(n.name), max(t.topic)"
+                " FROM rooms AS r"
+                " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
+                " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
+                " WHERE r.is_public = ?"
+                " GROUP BY r.room_id"
             ) % {
                 "topic": topic_subquery,
                 "name": name_subquery,
             }

-            c = txn.execute(sql, (is_public,))
-            return c.fetchall()
+            txn.execute(sql, (is_public,))
+
+            rows = txn.fetchall()
+
+            for i, row in enumerate(rows):
+                room_id = row[0]
+                aliases = self._simple_select_onecol_txn(
+                    txn,
+                    table="room_aliases",
+                    keyvalues={
+                        "room_id": room_id
+                    },
+                    retcol="room_alias",
+                )
+
+                rows[i] = list(row) + [aliases]
+
+            return rows

         rows = yield self.runInteraction(
             "get_rooms", f
@@ -130,9 +144,10 @@ class RoomStore(SQLBaseStore):
                 "room_id": r[0],
                 "name": r[1],
                 "topic": r[2],
-                "aliases": r[3].split("\x1f"),
+                "aliases": r[3],
             }
             for r in rows
+            if r[3]  # We only return rooms that have at least one alias.
         ]

         defer.returnValue(ret)


@@ -40,7 +40,6 @@ class RoomMemberStore(SQLBaseStore):
         """
         try:
             target_user_id = event.state_key
-            domain = UserID.from_string(target_user_id).domain
         except:
             logger.exception(
                 "Failed to parse target_user_id=%s", target_user_id
@@ -65,42 +64,8 @@ class RoomMemberStore(SQLBaseStore):
             }
         )

-        # Update room hosts table
-        if event.membership == Membership.JOIN:
-            sql = (
-                "INSERT OR IGNORE INTO room_hosts (room_id, host) "
-                "VALUES (?, ?)"
-            )
-            txn.execute(sql, (event.room_id, domain))
-        elif event.membership != Membership.INVITE:
-            # Check if this was the last person to have left.
-            member_events = self._get_members_query_txn(
-                txn,
-                where_clause=("c.room_id = ? AND m.membership = ?"
-                              " AND m.user_id != ?"),
-                where_values=(event.room_id, Membership.JOIN, target_user_id,)
-            )
-
-            joined_domains = set()
-            for e in member_events:
-                try:
-                    joined_domains.add(
-                        UserID.from_string(e.state_key).domain
-                    )
-                except:
-                    # FIXME: How do we deal with invalid user ids in the db?
-                    logger.exception("Invalid user_id: %s", event.state_key)
-
-            if domain not in joined_domains:
-                sql = (
-                    "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
-                )
-                txn.execute(sql, (event.room_id, domain))
-
         self.get_rooms_for_user.invalidate(target_user_id)

-    @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -110,41 +75,27 @@ class RoomMemberStore(SQLBaseStore):
         Returns:
             Deferred: Results in a MembershipEvent or None.
         """
-        rows = yield self._get_members_by_dict({
-            "e.room_id": room_id,
-            "m.user_id": user_id,
-        })
+        def f(txn):
+            events = self._get_members_events_txn(
+                txn,
+                room_id,
+                user_id=user_id,
+            )

-        defer.returnValue(rows[0] if rows else None)
+            return events[0] if events else None

-    def _get_room_member(self, txn, user_id, room_id):
-        sql = (
-            "SELECT e.* FROM events as e"
-            " INNER JOIN room_memberships as m"
-            " ON e.event_id = m.event_id"
-            " INNER JOIN current_state_events as c"
-            " ON m.event_id = c.event_id"
-            " WHERE m.user_id = ? and e.room_id = ?"
-            " LIMIT 1"
-        )
-        txn.execute(sql, (user_id, room_id))
-        rows = self.cursor_to_dict(txn)
-        if rows:
-            return self._parse_events_txn(txn, rows)[0]
-        else:
-            return None
+        return self.runInteraction("get_room_member", f)

     def get_users_in_room(self, room_id):
         def f(txn):
-            sql = (
-                "SELECT m.user_id FROM room_memberships as m"
-                " INNER JOIN current_state_events as c"
-                " ON m.event_id = c.event_id"
-                " WHERE m.membership = ? AND m.room_id = ?"
-            )
+            rows = self._get_members_rows_txn(
+                txn,
+                room_id=room_id,
+                membership=Membership.JOIN,
+            )

-            txn.execute(sql, (Membership.JOIN, room_id))
-            return [r[0] for r in txn.fetchall()]
+            return [r["user_id"] for r in rows]

         return self.runInteraction("get_users_in_room", f)

     def get_room_members(self, room_id, membership=None):
@@ -159,11 +110,14 @@ class RoomMemberStore(SQLBaseStore):
             list of namedtuples representing the members in this room.
         """
-        where = {"m.room_id": room_id}
-        if membership:
-            where["m.membership"] = membership
+        def f(txn):
+            return self._get_members_events_txn(
+                txn,
+                room_id,
+                membership=membership,
+            )

-        return self._get_members_by_dict(where)
+        return self.runInteraction("get_room_members", f)

     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
@@ -209,32 +163,55 @@ class RoomMemberStore(SQLBaseStore):
         ]

     def get_joined_hosts_for_room(self, room_id):
-        return self._simple_select_onecol(
-            "room_hosts",
-            {"room_id": room_id},
-            "host",
-            desc="get_joined_hosts_for_room",
+        return self.runInteraction(
+            "get_joined_hosts_for_room",
+            self._get_joined_hosts_for_room_txn,
+            room_id,
         )

-    def _get_members_by_dict(self, where_dict):
-        clause = " AND ".join("%s = ?" % k for k in where_dict.keys())
-        vals = where_dict.values()
-        return self._get_members_query(clause, vals)
+    def _get_joined_hosts_for_room_txn(self, txn, room_id):
+        rows = self._get_members_rows_txn(
+            txn,
+            room_id, membership=Membership.JOIN
+        )
+
+        joined_domains = set(
+            UserID.from_string(r["user_id"]).domain
+            for r in rows
+        )
+
+        return joined_domains

     def _get_members_query(self, where_clause, where_values):
         return self.runInteraction(
-            "get_members_query", self._get_members_query_txn,
+            "get_members_query", self._get_members_events_txn,
             where_clause, where_values
         )

-    def _get_members_query_txn(self, txn, where_clause, where_values):
+    def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
+        rows = self._get_members_rows_txn(
+            txn,
+            room_id, membership, user_id,
+        )
+        return self._get_events_txn(txn, [r["event_id"] for r in rows])
+
+    def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
+        where_clause = "c.room_id = ?"
+        where_values = [room_id]
+
+        if membership:
+            where_clause += " AND m.membership = ?"
+            where_values.append(membership)
+
+        if user_id:
+            where_clause += " AND m.user_id = ?"
+            where_values.append(user_id)
+
         sql = (
-            "SELECT e.* FROM events as e "
-            "INNER JOIN room_memberships as m "
-            "ON e.event_id = m.event_id "
-            "INNER JOIN current_state_events as c "
-            "ON m.event_id = c.event_id "
-            "WHERE %(where)s "
+            "SELECT m.* FROM room_memberships as m"
+            " INNER JOIN current_state_events as c"
+            " ON m.event_id = c.event_id"
+            " WHERE %(where)s"
         ) % {
             "where": where_clause,
         }
@@ -242,8 +219,7 @@ class RoomMemberStore(SQLBaseStore):
         txn.execute(sql, where_values)
         rows = self.cursor_to_dict(txn)

-        results = self._parse_events_txn(txn, rows)
-        return results
+        return rows

     @cached()
     def get_rooms_for_user(self, user_id):


@@ -17,26 +17,25 @@ CREATE TABLE IF NOT EXISTS rejections(
     event_id TEXT NOT NULL,
     reason TEXT NOT NULL,
     last_check TEXT NOT NULL,
-    CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+    UNIQUE (event_id)
 );

 -- Push notification endpoints that users have configured
 CREATE TABLE IF NOT EXISTS pushers (
   id INTEGER PRIMARY KEY AUTOINCREMENT,
   user_name TEXT NOT NULL,
-  profile_tag varchar(32) NOT NULL,
-  kind varchar(8) NOT NULL,
-  app_id varchar(64) NOT NULL,
-  app_display_name varchar(64) NOT NULL,
-  device_display_name varchar(128) NOT NULL,
-  pushkey blob NOT NULL,
-  ts BIGINT NOT NULL,
-  lang varchar(8),
-  data blob,
+  profile_tag VARCHAR(32) NOT NULL,
+  kind VARCHAR(8) NOT NULL,
+  app_id VARCHAR(64) NOT NULL,
+  app_display_name VARCHAR(64) NOT NULL,
+  device_display_name VARCHAR(128) NOT NULL,
+  pushkey VARBINARY(512) NOT NULL,
+  ts BIGINT UNSIGNED NOT NULL,
+  lang VARCHAR(8),
+  data LONGBLOB,
   last_token TEXT,
-  last_success BIGINT,
-  failing_since BIGINT,
-  FOREIGN KEY(user_name) REFERENCES users(name),
+  last_success BIGINT UNSIGNED,
+  failing_since BIGINT UNSIGNED,
   UNIQUE (app_id, pushkey)
 );
@@ -55,13 +54,10 @@ CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
 CREATE TABLE IF NOT EXISTS user_filters(
   user_id TEXT,
-  filter_id INTEGER,
-  filter_json TEXT,
-  FOREIGN KEY(user_id) REFERENCES users(id)
+  filter_id BIGINT UNSIGNED,
+  filter_json LONGBLOB
 );

 CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
   user_id, filter_id
 );
-
-PRAGMA user_version = 12;


@@ -19,16 +19,13 @@ CREATE TABLE IF NOT EXISTS application_services(
     token TEXT,
     hs_token TEXT,
     sender TEXT,
-    UNIQUE(token) ON CONFLICT ROLLBACK
+    UNIQUE(token)
 );

 CREATE TABLE IF NOT EXISTS application_services_regex(
     id INTEGER PRIMARY KEY AUTOINCREMENT,
-    as_id INTEGER NOT NULL,
+    as_id BIGINT UNSIGNED NOT NULL,
     namespace INTEGER,  /* enum[room_id|room_alias|user_id] */
     regex TEXT,
     FOREIGN KEY(as_id) REFERENCES application_services(id)
 );


@@ -15,16 +15,17 @@
 CREATE TABLE IF NOT EXISTS application_services_state(
     as_id TEXT PRIMARY KEY,
-    state TEXT,
-    last_txn TEXT
+    state VARCHAR(5),
+    last_txn INTEGER
 );

 CREATE TABLE IF NOT EXISTS application_services_txns(
     as_id TEXT NOT NULL,
     txn_id INTEGER NOT NULL,
     event_ids TEXT NOT NULL,
-    UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK
+    UNIQUE(as_id, txn_id)
 );
+
+CREATE INDEX IF NOT EXISTS application_services_txns_id ON application_services_txns (
+    as_id
+);


@@ -0,0 +1,2 @@
CREATE INDEX IF NOT EXISTS presence_list_user_id ON presence_list (user_id);


@@ -0,0 +1,4 @@
CREATE INDEX events_order ON events (topological_ordering, stream_ordering);
CREATE INDEX events_order_room ON events (
room_id, topological_ordering, stream_ordering
);


@@ -0,0 +1,2 @@
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON remote_media_cache_thumbnails (media_id);


@@ -0,0 +1,9 @@
DELETE FROM event_to_state_groups WHERE state_group not in (
    SELECT MAX(state_group) FROM event_to_state_groups GROUP BY event_id
);

DELETE FROM event_to_state_groups WHERE rowid not in (
    SELECT MIN(rowid) FROM event_to_state_groups GROUP BY event_id
);


@@ -0,0 +1,3 @@
CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id);
CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias);


@@ -0,0 +1,80 @@
-- We can use SQLite features here, since other db support was only added in v16
--
DELETE FROM current_state_events WHERE rowid not in (
    SELECT MIN(rowid) FROM current_state_events GROUP BY event_id
);
DROP INDEX IF EXISTS current_state_events_event_id;
CREATE UNIQUE INDEX current_state_events_event_id ON current_state_events(event_id);

--
DELETE FROM room_memberships WHERE rowid not in (
    SELECT MIN(rowid) FROM room_memberships GROUP BY event_id
);
DROP INDEX IF EXISTS room_memberships_event_id;
CREATE UNIQUE INDEX room_memberships_event_id ON room_memberships(event_id);

--
DELETE FROM feedback WHERE rowid not in (
    SELECT MIN(rowid) FROM feedback GROUP BY event_id
);
DROP INDEX IF EXISTS feedback_event_id;
CREATE UNIQUE INDEX feedback_event_id ON feedback(event_id);

--
DELETE FROM topics WHERE rowid not in (
    SELECT MIN(rowid) FROM topics GROUP BY event_id
);
DROP INDEX IF EXISTS topics_event_id;
CREATE UNIQUE INDEX topics_event_id ON topics(event_id);

--
DELETE FROM room_names WHERE rowid not in (
    SELECT MIN(rowid) FROM room_names GROUP BY event_id
);
DROP INDEX IF EXISTS room_names_id;
CREATE UNIQUE INDEX room_names_id ON room_names(event_id);

--
DELETE FROM presence WHERE rowid not in (
    SELECT MIN(rowid) FROM presence GROUP BY user_id
);
DROP INDEX IF EXISTS presence_id;
CREATE UNIQUE INDEX presence_id ON presence(user_id);

--
DELETE FROM presence_allow_inbound WHERE rowid not in (
    SELECT MIN(rowid) FROM presence_allow_inbound
    GROUP BY observed_user_id, observer_user_id
);
DROP INDEX IF EXISTS presence_allow_inbound_observers;
CREATE UNIQUE INDEX presence_allow_inbound_observers ON presence_allow_inbound(
    observed_user_id, observer_user_id
);

--
DELETE FROM presence_list WHERE rowid not in (
    SELECT MIN(rowid) FROM presence_list
    GROUP BY user_id, observed_user_id
);
DROP INDEX IF EXISTS presence_list_observers;
CREATE UNIQUE INDEX presence_list_observers ON presence_list(
    user_id, observed_user_id
);

--
DELETE FROM room_aliases WHERE rowid not in (
    SELECT MIN(rowid) FROM room_aliases GROUP BY room_alias
);
DROP INDEX IF EXISTS room_aliases_id;
CREATE INDEX room_aliases_id ON room_aliases(room_id);


@@ -0,0 +1,56 @@
-- Convert `access_tokens`.user from rowids to user strings.
-- MUST BE DONE BEFORE REMOVING ID COLUMN FROM USERS TABLE BELOW
CREATE TABLE IF NOT EXISTS new_access_tokens(
    id BIGINT UNSIGNED PRIMARY KEY,
    user_id TEXT NOT NULL,
    device_id TEXT,
    token TEXT NOT NULL,
    last_used BIGINT UNSIGNED,
    UNIQUE(token)
);

INSERT INTO new_access_tokens
    SELECT a.id, u.name, a.device_id, a.token, a.last_used
    FROM access_tokens as a
    INNER JOIN users as u ON u.id = a.user_id;

DROP TABLE access_tokens;

ALTER TABLE new_access_tokens RENAME TO access_tokens;

-- Remove ID column from `users` table
CREATE TABLE IF NOT EXISTS new_users(
    name TEXT,
    password_hash TEXT,
    creation_ts BIGINT UNSIGNED,
    admin BOOL DEFAULT 0 NOT NULL,
    UNIQUE(name)
);

INSERT INTO new_users SELECT name, password_hash, creation_ts, admin FROM users;

DROP TABLE users;

ALTER TABLE new_users RENAME TO users;

-- Remove UNIQUE constraint from `user_ips` table
CREATE TABLE IF NOT EXISTS new_user_ips (
    user_id TEXT NOT NULL,
    access_token TEXT NOT NULL,
    device_id TEXT,
    ip TEXT NOT NULL,
    user_agent TEXT NOT NULL,
    last_seen BIGINT UNSIGNED NOT NULL
);

INSERT INTO new_user_ips
    SELECT user, access_token, device_id, ip, user_agent, last_seen FROM user_ips;

DROP TABLE user_ips;

ALTER TABLE new_user_ips RENAME TO user_ips;

CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user_id);
CREATE INDEX IF NOT EXISTS user_ips_user_ip ON user_ips(user_id, access_token, ip);


@@ -16,52 +16,52 @@
 CREATE TABLE IF NOT EXISTS event_forward_extremities(
     event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
-    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+    UNIQUE (event_id, room_id)
 );

-CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
-CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);

 CREATE TABLE IF NOT EXISTS event_backward_extremities(
     event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
-    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+    UNIQUE (event_id, room_id)
 );

-CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
-CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);

 CREATE TABLE IF NOT EXISTS event_edges(
     event_id TEXT NOT NULL,
     prev_event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
-    is_state INTEGER NOT NULL,
-    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
+    is_state BOOL NOT NULL,
+    UNIQUE (event_id, prev_event_id, room_id, is_state)
 );

-CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
-CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+CREATE INDEX ev_edges_id ON event_edges(event_id);
+CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);

 CREATE TABLE IF NOT EXISTS room_depth(
     room_id TEXT NOT NULL,
     min_depth INTEGER NOT NULL,
-    CONSTRAINT uniqueness UNIQUE (room_id)
+    UNIQUE (room_id)
 );

-CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+CREATE INDEX room_depth_room ON room_depth(room_id);

 create TABLE IF NOT EXISTS event_destinations(
     event_id TEXT NOT NULL,
     destination TEXT NOT NULL,
-    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
-    CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+    delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
+    UNIQUE (event_id, destination)
 );

-CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+CREATE INDEX event_destinations_id ON event_destinations(event_id);

 CREATE TABLE IF NOT EXISTS state_forward_extremities(
@@ -69,21 +69,21 @@ CREATE TABLE IF NOT EXISTS state_forward_extremities(
     room_id TEXT NOT NULL,
     type TEXT NOT NULL,
     state_key TEXT NOT NULL,
-    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+    UNIQUE (event_id, room_id)
 );

-CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+CREATE INDEX st_extrem_keys ON state_forward_extremities(
     room_id, type, state_key
 );
-CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);

 CREATE TABLE IF NOT EXISTS event_auth(
     event_id TEXT NOT NULL,
     auth_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
-    CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
+    UNIQUE (event_id, auth_id, room_id)
 );

-CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
-CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);
+CREATE INDEX evauth_edges_id ON event_auth(event_id);
+CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);


@@ -16,50 +16,40 @@
 CREATE TABLE IF NOT EXISTS event_content_hashes (
     event_id TEXT,
     algorithm TEXT,
-    hash BLOB,
-    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+    hash bytea,
+    UNIQUE (event_id, algorithm)
 );

-CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
-    event_id
-);
+CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);

 CREATE TABLE IF NOT EXISTS event_reference_hashes (
     event_id TEXT,
     algorithm TEXT,
-    hash BLOB,
-    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+    hash bytea,
+    UNIQUE (event_id, algorithm)
 );

-CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
-    event_id
-);
+CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);

 CREATE TABLE IF NOT EXISTS event_signatures (
     event_id TEXT,
     signature_name TEXT,
     key_id TEXT,
-    signature BLOB,
-    CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
+    signature bytea,
+    UNIQUE (event_id, signature_name, key_id)
 );

-CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
-    event_id
-);
+CREATE INDEX event_signatures_id ON event_signatures(event_id);

 CREATE TABLE IF NOT EXISTS event_edge_hashes(
     event_id TEXT,
     prev_event_id TEXT,
     algorithm TEXT,
-    hash BLOB,
-    CONSTRAINT uniqueness UNIQUE (
-        event_id, prev_event_id, algorithm
-    )
+    hash bytea,
+    UNIQUE (event_id, prev_event_id, algorithm)
 );

-CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
-    event_id
-);
+CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);


@@ -15,7 +15,7 @@
 CREATE TABLE IF NOT EXISTS events(
     stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT,
-    topological_ordering INTEGER NOT NULL,
+    topological_ordering BIGINT NOT NULL,
     event_id TEXT NOT NULL,
     type TEXT NOT NULL,
     room_id TEXT NOT NULL,
@@ -23,26 +23,24 @@ CREATE TABLE IF NOT EXISTS events(
     unrecognized_keys TEXT,
     processed BOOL NOT NULL,
     outlier BOOL NOT NULL,
-    depth INTEGER DEFAULT 0 NOT NULL,
-    CONSTRAINT ev_uniq UNIQUE (event_id)
+    depth BIGINT DEFAULT 0 NOT NULL,
+    UNIQUE (event_id)
 );

-CREATE INDEX IF NOT EXISTS events_event_id ON events (event_id);
-CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
-CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
-CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
+CREATE INDEX events_stream_ordering ON events (stream_ordering);
+CREATE INDEX events_topological_ordering ON events (topological_ordering);
+CREATE INDEX events_room_id ON events (room_id);

 CREATE TABLE IF NOT EXISTS event_json(
     event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
-    internal_metadata NOT NULL,
-    json BLOB NOT NULL,
-    CONSTRAINT ev_j_uniq UNIQUE (event_id)
+    internal_metadata TEXT NOT NULL,
+    json TEXT NOT NULL,
+    UNIQUE (event_id)
 );

-CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
-CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
+CREATE INDEX event_json_room_id ON event_json(room_id);

 CREATE TABLE IF NOT EXISTS state_events(
@@ -50,13 +48,13 @@ CREATE TABLE IF NOT EXISTS state_events(
     room_id TEXT NOT NULL,
     type TEXT NOT NULL,
     state_key TEXT NOT NULL,
-    prev_state TEXT
+    prev_state TEXT,
+    UNIQUE (event_id)
 );

-CREATE UNIQUE INDEX IF NOT EXISTS state_events_event_id ON state_events (event_id);
-CREATE INDEX IF NOT EXISTS state_events_room_id ON state_events (room_id);
-CREATE INDEX IF NOT EXISTS state_events_type ON state_events (type);
-CREATE INDEX IF NOT EXISTS state_events_state_key ON state_events (state_key);
+CREATE INDEX state_events_room_id ON state_events (room_id);
+CREATE INDEX state_events_type ON state_events (type);
+CREATE INDEX state_events_state_key ON state_events (state_key);

 CREATE TABLE IF NOT EXISTS current_state_events(
@@ -64,13 +62,13 @@ CREATE TABLE IF NOT EXISTS current_state_events(
     room_id TEXT NOT NULL,
     type TEXT NOT NULL,
     state_key TEXT NOT NULL,
-    CONSTRAINT curr_uniq UNIQUE (room_id, type, state_key) ON CONFLICT REPLACE
+    UNIQUE (room_id, type, state_key)
 );

-CREATE INDEX IF NOT EXISTS curr_events_event_id ON current_state_events (event_id);
-CREATE INDEX IF NOT EXISTS current_state_events_room_id ON current_state_events (room_id);
-CREATE INDEX IF NOT EXISTS current_state_events_type ON current_state_events (type);
-CREATE INDEX IF NOT EXISTS current_state_events_state_key ON current_state_events (state_key);
+CREATE INDEX curr_events_event_id ON current_state_events (event_id);
+CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
+CREATE INDEX current_state_events_type ON current_state_events (type);
+CREATE INDEX current_state_events_state_key ON current_state_events (state_key);

 CREATE TABLE IF NOT EXISTS room_memberships(
     event_id TEXT NOT NULL,
@@ -80,9 +78,9 @@ CREATE TABLE IF NOT EXISTS room_memberships(
     membership TEXT NOT NULL
 );

-CREATE INDEX IF NOT EXISTS room_memberships_event_id ON room_memberships (event_id);
-CREATE INDEX IF NOT EXISTS room_memberships_room_id ON room_memberships (room_id);
-CREATE INDEX IF NOT EXISTS room_memberships_user_id ON room_memberships (user_id);
+CREATE INDEX room_memberships_event_id ON room_memberships (event_id);
+CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
+CREATE INDEX room_memberships_user_id ON room_memberships (user_id);

 CREATE TABLE IF NOT EXISTS feedback(
     event_id TEXT NOT NULL,
@@ -98,8 +96,8 @@ CREATE TABLE IF NOT EXISTS topics(
     topic TEXT NOT NULL
 );

-CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id);
-CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id);
+CREATE INDEX topics_event_id ON topics(event_id);
+CREATE INDEX topics_room_id ON topics(room_id);

 CREATE TABLE IF NOT EXISTS room_names(
     event_id TEXT NOT NULL,
@@ -107,19 +105,19 @@ CREATE TABLE IF NOT EXISTS room_names(
     name TEXT NOT NULL
 );

-CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id);
-CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id);
+CREATE INDEX room_names_event_id ON room_names(event_id);
+CREATE INDEX room_names_room_id ON room_names(room_id);

 CREATE TABLE IF NOT EXISTS rooms(
     room_id TEXT PRIMARY KEY NOT NULL,
-    is_public INTEGER,
+    is_public BOOL,
     creator TEXT
 );

 CREATE TABLE IF NOT EXISTS room_hosts(
     room_id TEXT NOT NULL,
     host TEXT NOT NULL,
-    CONSTRAINT room_hosts_uniq UNIQUE (room_id, host) ON CONFLICT IGNORE
+    UNIQUE (room_id, host)
 );

-CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id);
+CREATE INDEX room_hosts_room_id ON room_hosts (room_id);


@@ -16,16 +16,16 @@ CREATE TABLE IF NOT EXISTS server_tls_certificates(
   server_name TEXT, -- Server name.
   fingerprint TEXT, -- Certificate fingerprint.
   from_server TEXT, -- Which key server the certificate was fetched from.
-  ts_added_ms INTEGER, -- When the certifcate was added.
-  tls_certificate BLOB, -- DER encoded x509 certificate.
-  CONSTRAINT uniqueness UNIQUE (server_name, fingerprint)
+  ts_added_ms BIGINT, -- When the certifcate was added.
+  tls_certificate bytea, -- DER encoded x509 certificate.
+  UNIQUE (server_name, fingerprint)
 );

 CREATE TABLE IF NOT EXISTS server_signature_keys(
   server_name TEXT, -- Server name.
   key_id TEXT, -- Key version.
   from_server TEXT, -- Which key server the key was fetched form.
-  ts_added_ms INTEGER, -- When the key was added.
-  verify_key BLOB, -- NACL verification key.
-  CONSTRAINT uniqueness UNIQUE (server_name, key_id)
+  ts_added_ms BIGINT, -- When the key was added.
+  verify_key bytea, -- NACL verification key.
+  UNIQUE (server_name, key_id)
 );


@@ -17,10 +17,10 @@ CREATE TABLE IF NOT EXISTS local_media_repository (
     media_id TEXT, -- The id used to refer to the media.
     media_type TEXT, -- The MIME-type of the media.
     media_length INTEGER, -- Length of the media in bytes.
-    created_ts INTEGER, -- When the content was uploaded in ms.
+    created_ts BIGINT, -- When the content was uploaded in ms.
     upload_name TEXT, -- The name the media was uploaded with.
     user_id TEXT, -- The user who uploaded the file.
-    CONSTRAINT uniqueness UNIQUE (media_id)
+    UNIQUE (media_id)
 );

 CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
@@ -30,23 +30,23 @@ CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
     thumbnail_type TEXT, -- The MIME-type of the thumbnail.
     thumbnail_method TEXT, -- The method used to make the thumbnail.
     thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
-    CONSTRAINT uniqueness UNIQUE (
+    UNIQUE (
         media_id, thumbnail_width, thumbnail_height, thumbnail_type
     )
 );

-CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
+CREATE INDEX local_media_repository_thumbnails_media_id
     ON local_media_repository_thumbnails (media_id);

 CREATE TABLE IF NOT EXISTS remote_media_cache (
     media_origin TEXT, -- The remote HS the media came from.
     media_id TEXT, -- The id used to refer to the media on that server.
     media_type TEXT, -- The MIME-type of the media.
-    created_ts INTEGER, -- When the content was uploaded in ms.
+    created_ts BIGINT, -- When the content was uploaded in ms.
     upload_name TEXT, -- The name the media was uploaded with.
     media_length INTEGER, -- Length of the media in bytes.
     filesystem_id TEXT, -- The name used to store the media on disk.
-    CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
+    UNIQUE (media_origin, media_id)
 );

 CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
@@ -58,11 +58,8 @@ CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
     thumbnail_type TEXT, -- The MIME-type of the thumbnail.
     thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
     filesystem_id TEXT, -- The name used to store the media on disk.
-    CONSTRAINT uniqueness UNIQUE (
+    UNIQUE (
         media_origin, media_id, thumbnail_width, thumbnail_height,
-        thumbnail_type, thumbnail_type
+        thumbnail_type
     )
 );
-
-CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
-    ON local_media_repository_thumbnails (media_id);


@@ -13,26 +13,23 @@
  * limitations under the License.
  */
 CREATE TABLE IF NOT EXISTS presence(
-  user_id INTEGER NOT NULL,
-  state INTEGER,
+  user_id TEXT NOT NULL,
+  state VARCHAR(20),
   status_msg TEXT,
-  mtime INTEGER, -- miliseconds since last state change
-  FOREIGN KEY(user_id) REFERENCES users(id)
+  mtime BIGINT -- miliseconds since last state change
 );

 -- For each of /my/ users which possibly-remote users are allowed to see their
 -- presence state
 CREATE TABLE IF NOT EXISTS presence_allow_inbound(
-  observed_user_id INTEGER NOT NULL,
-  observer_user_id TEXT, -- a UserID,
-  FOREIGN KEY(observed_user_id) REFERENCES users(id)
+  observed_user_id TEXT NOT NULL,
+  observer_user_id TEXT NOT NULL -- a UserID,
 );

 -- For each of /my/ users (watcher), which possibly-remote users are they
 -- watching?
 CREATE TABLE IF NOT EXISTS presence_list(
-  user_id INTEGER NOT NULL,
-  observed_user_id TEXT, -- a UserID,
-  accepted BOOLEAN,
-  FOREIGN KEY(user_id) REFERENCES users(id)
+  user_id TEXT NOT NULL,
+  observed_user_id TEXT NOT NULL, -- a UserID,
+  accepted BOOLEAN NOT NULL
 );


@@ -13,8 +13,7 @@
  * limitations under the License.
  */
 CREATE TABLE IF NOT EXISTS profiles(
-    user_id INTEGER NOT NULL,
+    user_id TEXT NOT NULL,
     displayname TEXT,
-    avatar_url TEXT,
-    FOREIGN KEY(user_id) REFERENCES users(id)
+    avatar_url TEXT
 );


@@ -15,8 +15,8 @@
 CREATE TABLE IF NOT EXISTS redactions (
     event_id TEXT NOT NULL,
     redacts TEXT NOT NULL,
-    CONSTRAINT ev_uniq UNIQUE (event_id)
+    UNIQUE (event_id)
 );

-CREATE INDEX IF NOT EXISTS redactions_event_id ON redactions (event_id);
-CREATE INDEX IF NOT EXISTS redactions_redacts ON redactions (redacts);
+CREATE INDEX redactions_event_id ON redactions (event_id);
+CREATE INDEX redactions_redacts ON redactions (redacts);


@@ -22,6 +22,3 @@ CREATE TABLE IF NOT EXISTS room_alias_servers(
     room_alias TEXT NOT NULL,
     server TEXT NOT NULL
 );


@@ -30,18 +30,11 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
 CREATE TABLE IF NOT EXISTS event_to_state_groups(
     event_id TEXT NOT NULL,
     state_group INTEGER NOT NULL,
-    CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id)
+    UNIQUE (event_id)
 );

-CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);
-
-CREATE INDEX IF NOT EXISTS state_groups_state_id ON state_groups_state(
-    state_group
-);
-CREATE INDEX IF NOT EXISTS state_groups_state_tuple ON state_groups_state(
-    room_id, type, state_key
-);
-CREATE INDEX IF NOT EXISTS event_to_state_groups_id ON event_to_state_groups(
-    event_id
-);
+CREATE INDEX state_groups_id ON state_groups(id);
+
+CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
+CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
+CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);


@@ -14,17 +14,16 @@
  */
 -- Stores what transaction ids we have received and what our response was
 CREATE TABLE IF NOT EXISTS received_transactions(
     transaction_id TEXT,
     origin TEXT,
-    ts INTEGER,
+    ts BIGINT,
     response_code INTEGER,
-    response_json TEXT,
-    has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx
-    CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
+    response_json bytea,
+    has_been_referenced SMALLINT DEFAULT 0, -- Whether thishas been referenced by a prev_tx
+    UNIQUE (transaction_id, origin)
 );
-CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
-CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
+CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
 -- Stores what transactions we've sent, what their response was (if we got one) and whether we have
@@ -35,17 +34,14 @@ CREATE TABLE IF NOT EXISTS sent_transactions(
     destination TEXT,
     response_code INTEGER DEFAULT 0,
     response_json TEXT,
-    ts INTEGER
+    ts BIGINT
 );
-CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
-CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
-    destination
-);
-CREATE INDEX IF NOT EXISTS sent_transaction_txn_id ON sent_transactions(transaction_id);
+CREATE INDEX sent_transaction_dest ON sent_transactions(destination);
+CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id);
 -- So that we can do an efficient look up of all transactions that have yet to be successfully
 -- sent.
-CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);
+CREATE INDEX sent_transaction_sent ON sent_transactions(response_code);
 -- For sent transactions only.
@@ -56,13 +52,12 @@ CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
     pdu_origin TEXT
 );
-CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
-CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
+CREATE INDEX transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
+CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
 -- To track destination health
 CREATE TABLE IF NOT EXISTS destinations(
     destination TEXT PRIMARY KEY,
-    retry_last_ts INTEGER,
+    retry_last_ts BIGINT,
     retry_interval INTEGER
 );


@@ -16,19 +16,18 @@ CREATE TABLE IF NOT EXISTS users(
     id INTEGER PRIMARY KEY AUTOINCREMENT,
     name TEXT,
     password_hash TEXT,
-    creation_ts INTEGER,
-    admin BOOL DEFAULT 0 NOT NULL,
-    UNIQUE(name) ON CONFLICT ROLLBACK
+    creation_ts BIGINT,
+    admin SMALLINT DEFAULT 0 NOT NULL,
+    UNIQUE(name)
 );
 CREATE TABLE IF NOT EXISTS access_tokens(
     id INTEGER PRIMARY KEY AUTOINCREMENT,
-    user_id INTEGER NOT NULL,
+    user_id TEXT NOT NULL,
     device_id TEXT,
     token TEXT NOT NULL,
-    last_used INTEGER,
-    FOREIGN KEY(user_id) REFERENCES users(id),
-    UNIQUE(token) ON CONFLICT ROLLBACK
+    last_used BIGINT,
+    UNIQUE(token)
 );
 CREATE TABLE IF NOT EXISTS user_ips (
@@ -37,9 +36,8 @@ CREATE TABLE IF NOT EXISTS user_ips (
     device_id TEXT,
     ip TEXT NOT NULL,
     user_agent TEXT NOT NULL,
-    last_seen INTEGER NOT NULL,
-    CONSTRAINT user_ip UNIQUE (user, access_token, ip, user_agent) ON CONFLICT REPLACE
+    last_seen BIGINT NOT NULL,
+    UNIQUE (user, access_token, ip, user_agent)
 );
-CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user);
+CREATE INDEX user_ips_user ON user_ips(user);
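A note on the dropped clauses: ON CONFLICT REPLACE / ON CONFLICT ROLLBACK are SQLite-only conflict policies with no Postgres equivalent, so the schemas keep the bare constraints and conflict handling moves into the storage layer. A minimal standalone sketch of the behavioural difference, using sqlite3 for brevity (illustrative only, not code from this commit):

import sqlite3

conn = sqlite3.connect(":memory:")

# Old SQLite-only form: a duplicate insert silently replaces the row.
conn.execute("CREATE TABLE t_old (name TEXT, UNIQUE(name) ON CONFLICT REPLACE)")
conn.execute("INSERT INTO t_old VALUES ('alice')")
conn.execute("INSERT INTO t_old VALUES ('alice')")  # no error, still one row
print(conn.execute("SELECT count(*) FROM t_old").fetchone())  # (1,)

# Portable form kept by this commit: the duplicate is an error that the
# application must handle (or avoid) itself.
conn.execute("CREATE TABLE t_new (name TEXT, UNIQUE(name))")
conn.execute("INSERT INTO t_new VALUES ('alice')")
try:
    conn.execute("INSERT INTO t_new VALUES ('alice')")
except sqlite3.IntegrityError as e:
    print("duplicate rejected:", e)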


@@ -0,0 +1,48 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS application_services(
id BIGINT PRIMARY KEY,
url TEXT,
token TEXT,
hs_token TEXT,
sender TEXT,
UNIQUE(token)
);
CREATE TABLE IF NOT EXISTS application_services_regex(
id BIGINT PRIMARY KEY,
as_id BIGINT NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id)
);
CREATE TABLE IF NOT EXISTS application_services_state(
as_id TEXT PRIMARY KEY,
state VARCHAR(5),
last_txn INTEGER
);
CREATE TABLE IF NOT EXISTS application_services_txns(
as_id TEXT NOT NULL,
txn_id INTEGER NOT NULL,
event_ids TEXT NOT NULL,
UNIQUE(as_id, txn_id)
);
CREATE INDEX application_services_txns_id ON application_services_txns (
as_id
);


@@ -0,0 +1,89 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, room_id)
);
CREATE INDEX ev_extrem_room ON event_forward_extremities(room_id);
CREATE INDEX ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, room_id)
);
CREATE INDEX ev_b_extrem_room ON event_backward_extremities(room_id);
CREATE INDEX ev_b_extrem_id ON event_backward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
is_state BOOL NOT NULL,
UNIQUE (event_id, prev_event_id, room_id, is_state)
);
CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT NOT NULL,
min_depth INTEGER NOT NULL,
UNIQUE (room_id)
);
CREATE INDEX room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations(
event_id TEXT NOT NULL,
destination TEXT NOT NULL,
delivered_ts BIGINT DEFAULT 0, -- or 0 if not delivered
UNIQUE (event_id, destination)
);
CREATE INDEX event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
UNIQUE (event_id, room_id)
);
CREATE INDEX st_extrem_keys ON state_forward_extremities(
room_id, type, state_key
);
CREATE INDEX st_extrem_id ON state_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_auth(
event_id TEXT NOT NULL,
auth_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, auth_id, room_id)
);
CREATE INDEX evauth_edges_id ON event_auth(event_id);
CREATE INDEX evauth_edges_auth_id ON event_auth(auth_id);


@@ -0,0 +1,55 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_content_hashes (
event_id TEXT,
algorithm TEXT,
hash bytea,
UNIQUE (event_id, algorithm)
);
CREATE INDEX event_content_hashes_id ON event_content_hashes(event_id);
CREATE TABLE IF NOT EXISTS event_reference_hashes (
event_id TEXT,
algorithm TEXT,
hash bytea,
UNIQUE (event_id, algorithm)
);
CREATE INDEX event_reference_hashes_id ON event_reference_hashes(event_id);
CREATE TABLE IF NOT EXISTS event_signatures (
event_id TEXT,
signature_name TEXT,
key_id TEXT,
signature bytea,
UNIQUE (event_id, signature_name, key_id)
);
CREATE INDEX event_signatures_id ON event_signatures(event_id);
CREATE TABLE IF NOT EXISTS event_edge_hashes(
event_id TEXT,
prev_event_id TEXT,
algorithm TEXT,
hash bytea,
UNIQUE (event_id, prev_event_id, algorithm)
);
CREATE INDEX event_edge_hashes_id ON event_edge_hashes(event_id);


@@ -0,0 +1,128 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS events(
stream_ordering INTEGER PRIMARY KEY,
topological_ordering BIGINT NOT NULL,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
room_id TEXT NOT NULL,
content TEXT NOT NULL,
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
depth BIGINT DEFAULT 0 NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX events_stream_ordering ON events (stream_ordering);
CREATE INDEX events_topological_ordering ON events (topological_ordering);
CREATE INDEX events_order ON events (topological_ordering, stream_ordering);
CREATE INDEX events_room_id ON events (room_id);
CREATE INDEX events_order_room ON events (
room_id, topological_ordering, stream_ordering
);
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata TEXT NOT NULL,
json TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX event_json_room_id ON event_json(room_id);
CREATE TABLE IF NOT EXISTS state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
prev_state TEXT,
UNIQUE (event_id)
);
CREATE INDEX state_events_room_id ON state_events (room_id);
CREATE INDEX state_events_type ON state_events (type);
CREATE INDEX state_events_state_key ON state_events (state_key);
CREATE TABLE IF NOT EXISTS current_state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
UNIQUE (event_id),
UNIQUE (room_id, type, state_key)
);
CREATE INDEX current_state_events_room_id ON current_state_events (room_id);
CREATE INDEX current_state_events_type ON current_state_events (type);
CREATE INDEX current_state_events_state_key ON current_state_events (state_key);
CREATE TABLE IF NOT EXISTS room_memberships(
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
sender TEXT NOT NULL,
room_id TEXT NOT NULL,
membership TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX room_memberships_room_id ON room_memberships (room_id);
CREATE INDEX room_memberships_user_id ON room_memberships (user_id);
CREATE TABLE IF NOT EXISTS feedback(
event_id TEXT NOT NULL,
feedback_type TEXT,
target_event_id TEXT,
sender TEXT,
room_id TEXT,
UNIQUE (event_id)
);
CREATE TABLE IF NOT EXISTS topics(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
topic TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX topics_room_id ON topics(room_id);
CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
name TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX room_names_room_id ON room_names(room_id);
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
is_public BOOL,
creator TEXT
);
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL,
UNIQUE (room_id, host)
);
CREATE INDEX room_hosts_room_id ON room_hosts (room_id);


@@ -0,0 +1,31 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name.
fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from.
ts_added_ms BIGINT, -- When the certifcate was added.
tls_certificate bytea, -- DER encoded x509 certificate.
UNIQUE (server_name, fingerprint)
);
CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name.
key_id TEXT, -- Key version.
from_server TEXT, -- Which key server the key was fetched form.
ts_added_ms BIGINT, -- When the key was added.
verify_key bytea, -- NACL verification key.
UNIQUE (server_name, key_id)
);


@@ -0,0 +1,68 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes.
created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file.
UNIQUE (media_id)
);
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type
)
);
CREATE INDEX local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media.
created_ts BIGINT, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
UNIQUE (media_origin, media_id)
);
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_method TEXT, -- The method used to make the thumbnail
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height,
thumbnail_type
)
);
CREATE INDEX remote_media_cache_thumbnails_media_id
ON remote_media_cache_thumbnails (media_id);


@@ -0,0 +1,40 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS presence(
user_id TEXT NOT NULL,
state VARCHAR(20),
status_msg TEXT,
mtime BIGINT, -- miliseconds since last state change
UNIQUE (user_id)
);
-- For each of /my/ users which possibly-remote users are allowed to see their
-- presence state
CREATE TABLE IF NOT EXISTS presence_allow_inbound(
observed_user_id TEXT NOT NULL,
observer_user_id TEXT NOT NULL, -- a UserID,
UNIQUE (observed_user_id, observer_user_id)
);
-- For each of /my/ users (watcher), which possibly-remote users are they
-- watching?
CREATE TABLE IF NOT EXISTS presence_list(
user_id TEXT NOT NULL,
observed_user_id TEXT NOT NULL, -- a UserID,
accepted BOOLEAN NOT NULL,
UNIQUE (user_id, observed_user_id)
);
CREATE INDEX presence_list_user_id ON presence_list (user_id);


@@ -0,0 +1,20 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS profiles(
user_id TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
UNIQUE(user_id)
);


@@ -0,0 +1,73 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL,
reason TEXT NOT NULL,
last_check TEXT NOT NULL,
UNIQUE (event_id)
);
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
profile_tag VARCHAR(32) NOT NULL,
kind VARCHAR(8) NOT NULL,
app_id VARCHAR(64) NOT NULL,
app_display_name VARCHAR(64) NOT NULL,
device_display_name VARCHAR(128) NOT NULL,
pushkey bytea NOT NULL,
ts BIGINT NOT NULL,
lang VARCHAR(8),
data bytea,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class SMALLINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX push_rules_user_name on push_rules (user_name);
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id BIGINT,
filter_json bytea
);
CREATE INDEX user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);
CREATE TABLE IF NOT EXISTS push_rules_enable (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
enabled SMALLINT,
UNIQUE(user_name, rule_id)
);
CREATE INDEX push_rules_enable_user_name on push_rules_enable (user_name);


@@ -0,0 +1,22 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS redactions (
event_id TEXT NOT NULL,
redacts TEXT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX redactions_event_id ON redactions (event_id);
CREATE INDEX redactions_redacts ON redactions (redacts);


@@ -0,0 +1,29 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS room_aliases(
room_alias TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (room_alias)
);
CREATE INDEX room_aliases_id ON room_aliases(room_id);
CREATE TABLE IF NOT EXISTS room_alias_servers(
room_alias TEXT NOT NULL,
server TEXT NOT NULL
);
CREATE INDEX room_alias_servers_alias ON room_alias_servers(room_alias);


@@ -0,0 +1,40 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS state_groups(
id BIGINT PRIMARY KEY,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS state_groups_state(
state_group BIGINT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL,
state_group BIGINT NOT NULL,
UNIQUE (event_id)
);
CREATE INDEX state_groups_id ON state_groups(id);
CREATE INDEX state_groups_state_id ON state_groups_state(state_group);
CREATE INDEX state_groups_state_tuple ON state_groups_state(room_id, type, state_key);
CREATE INDEX event_to_state_groups_id ON event_to_state_groups(event_id);


@@ -0,0 +1,63 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores what transaction ids we have received and what our response was
CREATE TABLE IF NOT EXISTS received_transactions(
transaction_id TEXT,
origin TEXT,
ts BIGINT,
response_code INTEGER,
response_json bytea,
has_been_referenced smallint default 0, -- Whether thishas been referenced by a prev_tx
UNIQUE (transaction_id, origin)
);
CREATE INDEX transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
-- since referenced the transaction in another outgoing transaction
CREATE TABLE IF NOT EXISTS sent_transactions(
id BIGINT PRIMARY KEY, -- This is used to apply insertion ordering
transaction_id TEXT,
destination TEXT,
response_code INTEGER DEFAULT 0,
response_json TEXT,
ts BIGINT
);
CREATE INDEX sent_transaction_dest ON sent_transactions(destination);
CREATE INDEX sent_transaction_txn_id ON sent_transactions(transaction_id);
-- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent.
CREATE INDEX sent_transaction_sent ON sent_transactions(response_code);
-- For sent transactions only.
CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
transaction_id INTEGER,
destination TEXT,
pdu_id TEXT,
pdu_origin TEXT,
UNIQUE (transaction_id, destination)
);
CREATE INDEX transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
-- To track destination health
CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY,
retry_last_ts BIGINT,
retry_interval INTEGER
);


@@ -0,0 +1,42 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS users(
name TEXT,
password_hash TEXT,
creation_ts BIGINT,
admin SMALLINT DEFAULT 0 NOT NULL,
UNIQUE(name)
);
CREATE TABLE IF NOT EXISTS access_tokens(
id BIGINT PRIMARY KEY,
user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
last_used BIGINT,
UNIQUE(token)
);
CREATE TABLE IF NOT EXISTS user_ips (
user_id TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT,
ip TEXT NOT NULL,
user_agent TEXT NOT NULL,
last_seen BIGINT NOT NULL
);
CREATE INDEX user_ips_user ON user_ips(user_id);
CREATE INDEX user_ips_user_ip ON user_ips(user_id, access_token, ip);


@@ -14,17 +14,14 @@
  */
 CREATE TABLE IF NOT EXISTS schema_version(
-    Lock char(1) NOT NULL DEFAULT 'X',  -- Makes sure this table only has one row.
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
     version INTEGER NOT NULL,
     upgraded BOOL NOT NULL,  -- Whether we reached this version from an upgrade or an initial schema.
-    CONSTRAINT schema_version_lock_x CHECK (Lock='X')
-    CONSTRAINT schema_version_lock_uniq UNIQUE (Lock)
+    CHECK (Lock='X')
 );
 CREATE TABLE IF NOT EXISTS applied_schema_deltas(
     version INTEGER NOT NULL,
     file TEXT NOT NULL,
-    CONSTRAINT schema_deltas_ver_file UNIQUE (version, file) ON CONFLICT IGNORE
+    UNIQUE(version, file)
 );
-CREATE INDEX IF NOT EXISTS schema_deltas_ver ON applied_schema_deltas(version);
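The rewritten Lock column is a portable single-row guard: UNIQUE permits only one 'X' and CHECK (Lock='X') forbids any other value, replacing the two named constraints SQLite allowed. A minimal standalone sketch of the effect, using sqlite3 for brevity (illustrative only, not code from this commit):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE schema_version("
    " Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,"
    " version INTEGER NOT NULL,"
    " upgraded BOOL NOT NULL,"
    " CHECK (Lock='X'))"
)
conn.execute("INSERT INTO schema_version (version, upgraded) VALUES (15, 0)")
try:
    # Any second row necessarily reuses Lock='X', violating UNIQUE.
    conn.execute("INSERT INTO schema_version (version, upgraded) VALUES (16, 1)")
except sqlite3.IntegrityError as e:
    print("second row rejected:", e)

# Version bumps therefore always UPDATE the one existing row.
conn.execute("UPDATE schema_version SET version = 16, upgraded = 1")
print(conn.execute("SELECT version, upgraded FROM schema_version").fetchall())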


@@ -56,7 +56,6 @@ class SignatureStore(SQLBaseStore):
                 "algorithm": algorithm,
                 "hash": buffer(hash_bytes),
             },
-            or_ignore=True,
         )

     def get_event_reference_hashes(self, event_ids):
@@ -100,7 +99,7 @@ class SignatureStore(SQLBaseStore):
             " WHERE event_id = ?"
         )
         txn.execute(query, (event_id, ))
-        return dict(txn.fetchall())
+        return {k: v for k, v in txn.fetchall()}

     def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
                                         hash_bytes):
@@ -119,7 +118,6 @@ class SignatureStore(SQLBaseStore):
                 "algorithm": algorithm,
                 "hash": buffer(hash_bytes),
             },
-            or_ignore=True,
         )

     def _get_event_signatures_txn(self, txn, event_id):
@@ -164,7 +162,6 @@ class SignatureStore(SQLBaseStore):
                 "key_id": key_id,
                 "signature": buffer(signature_bytes),
             },
-            or_ignore=True,
         )

     def _get_prev_event_hashes_txn(self, txn, event_id):
@@ -198,5 +195,4 @@ class SignatureStore(SQLBaseStore):
                 "algorithm": algorithm,
                 "hash": buffer(hash_bytes),
             },
-            or_ignore=True,
         )


@@ -17,6 +17,8 @@ from ._base import SQLBaseStore

 from twisted.internet import defer

+from synapse.util.stringutils import random_string
+
 import logging

 logger = logging.getLogger(__name__)
@@ -91,14 +93,15 @@ class StateStore(SQLBaseStore):
             state_group = context.state_group
             if not state_group:
-                state_group = self._simple_insert_txn(
+                state_group = self._state_groups_id_gen.get_next_txn(txn)
+                self._simple_insert_txn(
                     txn,
                     table="state_groups",
                     values={
+                        "id": state_group,
                         "room_id": event.room_id,
                         "event_id": event.event_id,
                     },
-                    or_ignore=True,
                 )

             for state in state_events.values():
@@ -112,7 +115,6 @@ class StateStore(SQLBaseStore):
                         "state_key": state.state_key,
                         "event_id": state.event_id,
                     },
-                    or_ignore=True,
                 )

             self._simple_insert_txn(
@@ -122,7 +124,6 @@ class StateStore(SQLBaseStore):
                     "state_group": state_group,
                     "event_id": event.event_id,
                 },
-                or_replace=True,
             )

     @defer.inlineCallbacks
@@ -154,3 +155,7 @@ class StateStore(SQLBaseStore):
         events = yield self._parse_events(results)

         defer.returnValue(events)
+
+
+def _make_group_id(clock):
+    return str(int(clock.time_msec())) + random_string(5)


@@ -35,7 +35,7 @@ what sort order was used:

 from twisted.internet import defer

-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.util.logutils import log_function
@@ -110,7 +110,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
         if self.topological is None:
             return "(%d < %s)" % (self.stream, "stream_ordering")
         else:
-            return "(%d < %s OR (%d == %s AND %d < %s))" % (
+            return "(%d < %s OR (%d = %s AND %d < %s))" % (
                 self.topological, "topological_ordering",
                 self.topological, "topological_ordering",
                 self.stream, "stream_ordering",
@@ -120,7 +120,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
         if self.topological is None:
             return "(%d >= %s)" % (self.stream, "stream_ordering")
         else:
-            return "(%d > %s OR (%d == %s AND %d >= %s))" % (
+            return "(%d > %s OR (%d = %s AND %d >= %s))" % (
                 self.topological, "topological_ordering",
                 self.topological, "topological_ordering",
                 self.stream, "stream_ordering",
@@ -240,7 +240,7 @@ class StreamStore(SQLBaseStore):
         sql = (
             "SELECT e.event_id, e.stream_ordering FROM events AS e WHERE "
-            "(e.outlier = 0 AND (room_id IN (%(current)s)) OR "
+            "(e.outlier = ? AND (room_id IN (%(current)s)) OR "
             "(event_id IN (%(invites)s))) "
             "AND e.stream_ordering > ? AND e.stream_ordering <= ? "
             "ORDER BY stream_ordering ASC LIMIT %(limit)d "
@@ -251,7 +251,7 @@ class StreamStore(SQLBaseStore):
         }

         def f(txn):
-            txn.execute(sql, (user_id, user_id, from_id.stream, to_id.stream,))
+            txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))

             rows = self.cursor_to_dict(txn)
@@ -283,7 +283,7 @@ class StreamStore(SQLBaseStore):
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
         # we have a bit of asymmetry when it comes to equalities.
-        args = [room_id]
+        args = [False, room_id]
         if direction == 'b':
             order = "DESC"
             bounds = _StreamToken.parse(from_key).upper_bound()
@@ -307,7 +307,7 @@ class StreamStore(SQLBaseStore):
         sql = (
             "SELECT * FROM events"
-            " WHERE outlier = 0 AND room_id = ? AND %(bounds)s"
+            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
             " ORDER BY topological_ordering %(order)s,"
             " stream_ordering %(order)s %(limit)s"
         ) % {
@@ -358,7 +358,7 @@ class StreamStore(SQLBaseStore):
         sql = (
             "SELECT stream_ordering, topological_ordering, event_id"
             " FROM events"
-            " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0"
+            " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
             " ORDER BY topological_ordering DESC, stream_ordering DESC"
             " LIMIT ?"
         )
@@ -368,17 +368,17 @@ class StreamStore(SQLBaseStore):
             "SELECT stream_ordering, topological_ordering, event_id"
             " FROM events"
             " WHERE room_id = ? AND stream_ordering > ?"
-            " AND stream_ordering <= ? AND outlier = 0"
+            " AND stream_ordering <= ? AND outlier = ?"
             " ORDER BY topological_ordering DESC, stream_ordering DESC"
             " LIMIT ?"
         )

         def get_recent_events_for_room_txn(txn):
             if from_token is None:
-                txn.execute(sql, (room_id, end_token.stream, limit,))
+                txn.execute(sql, (room_id, end_token.stream, False, limit,))
             else:
                 txn.execute(sql, (
-                    room_id, from_token.stream, end_token.stream, limit
+                    room_id, from_token.stream, end_token.stream, False, limit
                 ))

             rows = self.cursor_to_dict(txn)
@@ -413,12 +413,10 @@ class StreamStore(SQLBaseStore):
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )

-    @cached(num_args=0)
+    @defer.inlineCallbacks
     def get_room_events_max_id(self):
-        return self.runInteraction(
-            "get_room_events_max_id",
-            self._get_room_events_max_id_txn
-        )
+        token = yield self._stream_id_gen.get_max_token(self)
+        defer.returnValue("s%d" % (token,))

     @defer.inlineCallbacks
     def _get_min_token(self):
@@ -433,27 +431,6 @@ class StreamStore(SQLBaseStore):

         defer.returnValue(self.min_token)

-    def get_next_stream_id(self):
-        with self._next_stream_id_lock:
-            i = self._next_stream_id
-            self._next_stream_id += 1
-            return i
-
-    def _get_room_events_max_id_txn(self, txn):
-        txn.execute(
-            "SELECT MAX(stream_ordering) as m FROM events"
-        )
-
-        res = self.cursor_to_dict(txn)
-
-        logger.debug("get_room_events_max_id: %s", res)
-
-        if not res or not res[0] or not res[0]["m"]:
-            return "s0"
-
-        key = res[0]["m"]
-        return "s%d" % (key,)
-
     @staticmethod
     def _set_before_and_after(events, rows):
         for event, row in zip(events, rows):
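Two details in the stream changes above: `==` becomes `=` because the double-equals spelling of equality is SQLite-only, and the literal `outlier = 0` comparisons become bound `?` parameters so that each database engine can marshal the boolean itself (SQLite stores it as an integer, Postgres as a true BOOLEAN). For illustration, with hypothetical token values the rewritten upper bound renders as:

# Illustrative only: rendering _StreamToken's upper bound for
# topological=3, stream=10, using the format string shown above.
topological, stream = 3, 10
upper = "(%d < %s OR (%d = %s AND %d < %s))" % (
    topological, "topological_ordering",
    topological, "topological_ordering",
    stream, "stream_ordering",
)
print(upper)
# (3 < topological_ordering OR (3 = topological_ordering AND 10 < stream_ordering))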


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import SQLBaseStore, Table, cached
+from ._base import SQLBaseStore, cached

 from collections import namedtuple

@@ -76,22 +76,18 @@ class TransactionStore(SQLBaseStore):
             response_json (str)
         """

-        return self.runInteraction(
-            "set_received_txn_response",
-            self._set_received_txn_response,
-            transaction_id, origin, code, response_dict
+        return self._simple_insert(
+            table=ReceivedTransactionsTable.table_name,
+            values={
+                "transaction_id": transaction_id,
+                "origin": origin,
+                "response_code": code,
+                "response_json": response_dict,
+            },
+            or_ignore=True,
+            desc="set_received_txn_response",
         )

-    def _set_received_txn_response(self, txn, transaction_id, origin, code,
-                                   response_json):
-        query = (
-            "UPDATE %s "
-            "SET response_code = ?, response_json = ? "
-            "WHERE transaction_id = ? AND origin = ?"
-        ) % ReceivedTransactionsTable.table_name
-
-        txn.execute(query, (code, response_json, transaction_id, origin))
-
     def prep_send_transaction(self, transaction_id, destination,
                               origin_server_ts):
         """Persists an outgoing transaction and calculates the values for the
@@ -118,41 +114,38 @@ class TransactionStore(SQLBaseStore):

     def _prep_send_transaction(self, txn, transaction_id, destination,
                                origin_server_ts):
+        next_id = self._transaction_id_gen.get_next_txn(txn)

         # First we find out what the prev_txns should be.
         # Since we know that we are only sending one transaction at a time,
         # we can simply take the last one.
-        query = "%s ORDER BY id DESC LIMIT 1" % (
-            SentTransactions.select_statement("destination = ?"),
-        )
+        query = (
+            "SELECT * FROM sent_transactions"
+            " WHERE destination = ?"
+            " ORDER BY id DESC LIMIT 1"
+        )

-        results = txn.execute(query, (destination,))
-        results = SentTransactions.decode_results(results)
+        txn.execute(query, (destination,))
+        results = self.cursor_to_dict(txn)

-        prev_txns = [r.transaction_id for r in results]
+        prev_txns = [r["transaction_id"] for r in results]

         # Actually add the new transaction to the sent_transactions table.

-        query = SentTransactions.insert_statement()
-        txn.execute(query, SentTransactions.EntryType(
-            None,
-            transaction_id=transaction_id,
-            destination=destination,
-            ts=origin_server_ts,
-            response_code=0,
-            response_json=None
-        ))
+        self._simple_insert_txn(
+            txn,
+            table=SentTransactions.table_name,
+            values={
+                "id": next_id,
+                "transaction_id": transaction_id,
+                "destination": destination,
+                "ts": origin_server_ts,
+                "response_code": 0,
+                "response_json": None,
+            }
+        )

-        # Update the tx id -> pdu id mapping
-
-        # values = [
-        #     (transaction_id, destination, pdu[0], pdu[1])
-        #     for pdu in pdu_list
-        # ]
-        #
-        # logger.debug("Inserting: %s", repr(values))
-        #
-        # query = TransactionsToPduTable.insert_statement()
-        # txn.executemany(query, values)
+        # TODO Update the tx id -> pdu id mapping

         return prev_txns
@@ -171,15 +164,20 @@ class TransactionStore(SQLBaseStore):
             transaction_id, destination, code, response_dict
         )

-    def _delivered_txn(cls, txn, transaction_id, destination,
+    def _delivered_txn(self, txn, transaction_id, destination,
                        code, response_json):
-        query = (
-            "UPDATE %s "
-            "SET response_code = ?, response_json = ? "
-            "WHERE transaction_id = ? AND destination = ?"
-        ) % SentTransactions.table_name
-
-        txn.execute(query, (code, response_json, transaction_id, destination))
+        self._simple_update_one_txn(
+            txn,
+            table=SentTransactions.table_name,
+            keyvalues={
+                "transaction_id": transaction_id,
+                "destination": destination,
+            },
+            updatevalues={
+                "response_code": code,
+                "response_json": None,  # For now, don't persist response_json
+            }
+        )

     def get_transactions_after(self, transaction_id, destination):
         """Get all transactions after a given local transaction_id.
@@ -189,25 +187,26 @@ class TransactionStore(SQLBaseStore):
             destination (str)

         Returns:
-            list: A list of `ReceivedTransactionsTable.EntryType`
+            list: A list of dicts
         """

         return self.runInteraction(
             "get_transactions_after",
             self._get_transactions_after, transaction_id, destination
         )

-    def _get_transactions_after(cls, txn, transaction_id, destination):
-        where = (
-            "destination = ? AND id > (select id FROM %s WHERE "
-            "transaction_id = ? AND destination = ?)"
-        ) % (
-            SentTransactions.table_name
+    def _get_transactions_after(self, txn, transaction_id, destination):
+        query = (
+            "SELECT * FROM sent_transactions"
+            " WHERE destination = ? AND id >"
+            " ("
+            "  SELECT id FROM sent_transactions"
+            "  WHERE transaction_id = ? AND destination = ?"
+            " )"
         )
-        query = SentTransactions.select_statement(where)

         txn.execute(query, (destination, transaction_id, destination))

-        return ReceivedTransactionsTable.decode_results(txn.fetchall())
+        return self.cursor_to_dict(txn)

     @cached()
     def get_destination_retry_timings(self, destination):
@@ -218,22 +217,27 @@ class TransactionStore(SQLBaseStore):

         Returns:
             None if not retrying
-            Otherwise a DestinationsTable.EntryType for the retry scheme
+            Otherwise a dict for the retry scheme
         """

         return self.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings, destination)

-    def _get_destination_retry_timings(cls, txn, destination):
-        query = DestinationsTable.select_statement("destination = ?")
-        txn.execute(query, (destination,))
-        result = txn.fetchall()
-        if result:
-            result = DestinationsTable.decode_single_result(result)
-            if result.retry_last_ts > 0:
-                return result
-            else:
-                return None
+    def _get_destination_retry_timings(self, txn, destination):
+        result = self._simple_select_one_txn(
+            txn,
+            table=DestinationsTable.table_name,
+            keyvalues={
+                "destination": destination,
+            },
+            retcols=DestinationsTable.fields,
+            allow_none=True,
+        )
+
+        if result and result["retry_last_ts"] > 0:
+            return result
+        else:
+            return None

     def set_destination_retry_timings(self, destination,
                                       retry_last_ts, retry_interval):
@@ -249,11 +253,11 @@ class TransactionStore(SQLBaseStore):
         # As this is the new value, we might as well prefill the cache
         self.get_destination_retry_timings.prefill(
             destination,
-            DestinationsTable.EntryType(
-                destination,
-                retry_last_ts,
-                retry_interval
-            )
+            {
+                "destination": destination,
+                "retry_last_ts": retry_last_ts,
+                "retry_interval": retry_interval
+            },
         )

         # XXX: we could chose to not bother persisting this if our cache thinks
@@ -266,22 +270,38 @@ class TransactionStore(SQLBaseStore):
             retry_interval,
         )

-    def _set_destination_retry_timings(cls, txn, destination,
-                                       retry_last_ts, retry_interval):
+    def _set_destination_retry_timings(self, txn, destination,
+                                       retry_last_ts, retry_interval):
         query = (
-            "INSERT OR REPLACE INTO %s "
-            "(destination, retry_last_ts, retry_interval) "
-            "VALUES (?, ?, ?) "
-        ) % DestinationsTable.table_name
+            "UPDATE destinations"
+            " SET retry_last_ts = ?, retry_interval = ?"
+            " WHERE destination = ?"
+        )

-        txn.execute(query, (destination, retry_last_ts, retry_interval))
+        txn.execute(
+            query,
+            (
+                retry_last_ts, retry_interval, destination,
+            )
+        )
+
+        if txn.rowcount == 0:
+            # destination wasn't already in table. Insert it.
+            self._simple_insert_txn(
+                txn,
+                table="destinations",
+                values={
+                    "destination": destination,
+                    "retry_last_ts": retry_last_ts,
+                    "retry_interval": retry_interval,
+                }
+            )

     def get_destinations_needing_retry(self):
         """Get all destinations which are due a retry for sending a transaction.

         Returns:
-            list: A list of `DestinationsTable.EntryType`
+            list: A list of dicts
         """

         return self.runInteraction(
@@ -289,14 +309,17 @@ class TransactionStore(SQLBaseStore):
             self._get_destinations_needing_retry
         )

-    def _get_destinations_needing_retry(cls, txn):
-        where = "retry_last_ts > 0 and retry_next_ts < now()"
-        query = DestinationsTable.select_statement(where)
-        txn.execute(query)
-        return DestinationsTable.decode_results(txn.fetchall())
+    def _get_destinations_needing_retry(self, txn):
+        query = (
+            "SELECT * FROM destinations"
+            " WHERE retry_last_ts > 0 and retry_next_ts < ?"
+        )

+        txn.execute(query, (self._clock.time_msec(),))
+        return self.cursor_to_dict(txn)

-class ReceivedTransactionsTable(Table):
+
+class ReceivedTransactionsTable(object):
     table_name = "received_transactions"

     fields = [
@@ -308,10 +331,8 @@ class ReceivedTransactionsTable(object):
         "has_been_referenced",
     ]

-    EntryType = namedtuple("ReceivedTransactionsEntry", fields)

-
-class SentTransactions(Table):
+class SentTransactions(object):
     table_name = "sent_transactions"

     fields = [
@@ -326,7 +347,7 @@ class SentTransactions(object):
     EntryType = namedtuple("SentTransactionsEntry", fields)


-class TransactionsToPduTable(Table):
+class TransactionsToPduTable(object):
     table_name = "transaction_id_to_pdu"

     fields = [
@@ -336,10 +357,8 @@ class TransactionsToPduTable(object):
         "pdu_origin",
     ]

-    EntryType = namedtuple("TransactionsToPduEntry", fields)

-
-class DestinationsTable(Table):
+class DestinationsTable(object):
     table_name = "destinations"

     fields = [
@@ -347,5 +366,3 @@ class DestinationsTable(object):
         "retry_last_ts",
         "retry_interval",
     ]

-    EntryType = namedtuple("DestinationsEntry", fields)
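`_set_destination_retry_timings` above replaces SQLite's `INSERT OR REPLACE` with an update-then-insert sequence, since that syntax has no Postgres counterpart (and this predates standard UPSERT support); inside the surrounding transaction, the `rowcount` check makes the pair behave as a single upsert. The same pattern in isolation, against sqlite3 (illustrative sketch, not code from this commit):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE destinations("
    " destination TEXT PRIMARY KEY,"
    " retry_last_ts BIGINT,"
    " retry_interval INTEGER)"
)

def set_retry_timings(cur, destination, retry_last_ts, retry_interval):
    # Try the UPDATE first; rowcount reveals whether the row existed.
    cur.execute(
        "UPDATE destinations"
        " SET retry_last_ts = ?, retry_interval = ?"
        " WHERE destination = ?",
        (retry_last_ts, retry_interval, destination),
    )
    if cur.rowcount == 0:
        # Destination wasn't already in the table. Insert it.
        cur.execute(
            "INSERT INTO destinations"
            " (destination, retry_last_ts, retry_interval)"
            " VALUES (?, ?, ?)",
            (destination, retry_last_ts, retry_interval),
        )

cur = conn.cursor()
set_retry_timings(cur, "example.org", 1000, 60)   # row absent: inserts
set_retry_timings(cur, "example.org", 2000, 120)  # row present: updates
print(conn.execute("SELECT * FROM destinations").fetchall())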


@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer

from collections import deque

import contextlib
import threading


class IdGenerator(object):
    def __init__(self, table, column, store):
        self.table = table
        self.column = column
        self.store = store
        self._lock = threading.Lock()
        self._next_id = None

    @defer.inlineCallbacks
    def get_next(self):
        with self._lock:
            if not self._next_id:
                res = yield self.store._execute_and_decode(
                    "IdGenerator_%s" % (self.table,),
                    "SELECT MAX(%s) as mx FROM %s" % (self.column, self.table,)
                )

                self._next_id = (res and res[0] and res[0]["mx"]) or 1

            i = self._next_id
            self._next_id += 1
            defer.returnValue(i)

    def get_next_txn(self, txn):
        with self._lock:
            if self._next_id:
                i = self._next_id
                self._next_id += 1
                return i
            else:
                txn.execute(
                    "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
                )

                val, = txn.fetchone()
                cur = val or 0
                cur += 1
                self._next_id = cur + 1

                return cur


class StreamIdGenerator(object):
    """Used to generate new stream ids when persisting events while keeping
    track of which transactions have been completed.

    This allows us to get the "current" stream id, i.e. the stream id such that
    all ids less than or equal to it have completed. This handles the fact that
    persistence of events can complete out of order.

    Usage:
        with stream_id_gen.get_next_txn(txn) as stream_id:
            # ... persist event ...
    """
    def __init__(self):
        self._lock = threading.Lock()

        self._current_max = None
        self._unfinished_ids = deque()

    def get_next_txn(self, txn):
        """
        Usage:
            with stream_id_gen.get_next_txn(txn) as stream_id:
                # ... persist event ...
        """
        with self._lock:
            if not self._current_max:
                self._compute_current_max(txn)

            self._current_max += 1
            next_id = self._current_max

            self._unfinished_ids.append(next_id)

        @contextlib.contextmanager
        def manager():
            try:
                yield next_id
            finally:
                with self._lock:
                    self._unfinished_ids.remove(next_id)

        return manager()

    @defer.inlineCallbacks
    def get_max_token(self, store):
        """Returns the maximum stream id such that all stream ids less than or
        equal to it have been successfully persisted.
        """
        with self._lock:
            if self._unfinished_ids:
                defer.returnValue(self._unfinished_ids[0] - 1)

            if not self._current_max:
                yield store.runInteraction(
                    "_compute_current_max",
                    self._compute_current_max,
                )

            defer.returnValue(self._current_max)

    def _compute_current_max(self, txn):
        txn.execute("SELECT MAX(stream_ordering) FROM events")
        val, = txn.fetchone()

        self._current_max = int(val) if val else 1

        return self._current_max
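For reference, a condensed standalone re-statement of the idea behind StreamIdGenerator (illustrative only, not code from this commit): an id stays "unfinished" until its context manager exits, so the max token never advances past an id that is still being persisted, even when later ids complete first.

import contextlib
import threading
from collections import deque

class MiniStreamIdGen(object):
    # Same invariant as StreamIdGenerator, minus the Twisted plumbing.
    def __init__(self, current_max=0):
        self._lock = threading.Lock()
        self._current_max = current_max
        self._unfinished = deque()

    @contextlib.contextmanager
    def get_next(self):
        with self._lock:
            self._current_max += 1
            next_id = self._current_max
            self._unfinished.append(next_id)
        try:
            yield next_id
        finally:
            with self._lock:
                self._unfinished.remove(next_id)

    def get_max_token(self):
        with self._lock:
            if self._unfinished:
                # Everything strictly below the oldest unfinished id is safe.
                return self._unfinished[0] - 1
            return self._current_max

gen = MiniStreamIdGen()
first = gen.get_next()
second = gen.get_next()
first.__enter__()                  # id 1 in flight
second.__enter__()                 # id 2 in flight
second.__exit__(None, None, None)  # id 2 finishes first
print(gen.get_max_token())         # 0: id 1 is still being persisted
first.__exit__(None, None, None)
print(gen.get_max_token())         # 2: both persisted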


@@ -14,6 +14,10 @@
 # limitations under the License.

+from functools import wraps
+
+import threading
+

 class LruCache(object):
     """Least-recently-used cache."""
     # TODO(mjark) Add mutex for linked list for thread safety.
@@ -24,6 +28,16 @@ class LruCache(object):

         PREV, NEXT, KEY, VALUE = 0, 1, 2, 3

+        lock = threading.Lock()
+
+        def synchronized(f):
+            @wraps(f)
+            def inner(*args, **kwargs):
+                with lock:
+                    return f(*args, **kwargs)
+
+            return inner
+
         def add_node(key, value):
             prev_node = list_root
             next_node = prev_node[NEXT]
@@ -51,6 +65,7 @@ class LruCache(object):
             next_node[PREV] = prev_node
             cache.pop(node[KEY], None)

+        @synchronized
         def cache_get(key, default=None):
             node = cache.get(key, None)
             if node is not None:
@@ -59,6 +74,7 @@ class LruCache(object):
             else:
                 return default

+        @synchronized
         def cache_set(key, value):
             node = cache.get(key, None)
             if node is not None:
@@ -69,6 +85,7 @@ class LruCache(object):
             if len(cache) > max_size:
                 delete_node(list_root[PREV])

+        @synchronized
         def cache_set_default(key, value):
             node = cache.get(key, None)
             if node is not None:
@@ -79,6 +96,7 @@ class LruCache(object):
                 delete_node(list_root[PREV])
             return value

+        @synchronized
         def cache_pop(key, default=None):
             node = cache.get(key, None)
             if node:
@@ -87,9 +105,11 @@ class LruCache(object):
             else:
                 return default

+        @synchronized
         def cache_len():
             return len(cache)

+        @synchronized
         def cache_contains(key):
             return key in cache


@@ -60,7 +60,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):

     if retry_timings:
         retry_last_ts, retry_interval = (
-            retry_timings.retry_last_ts, retry_timings.retry_interval
+            retry_timings["retry_last_ts"], retry_timings["retry_interval"]
        )

         now = int(clock.time_msec())


@@ -24,8 +24,6 @@ from ..utils import MockHttpResource, MockClock, setup_test_homeserver
 from synapse.federation import initialize_http_replication
 from synapse.events import FrozenEvent

-from synapse.storage.transactions import DestinationsTable
-

 def make_pdu(prev_pdus=[], **kwargs):
     """Provide some default fields for making a PduTuple."""
@@ -57,8 +55,14 @@ class FederationTestCase(unittest.TestCase):
         self.mock_persistence.get_received_txn_response.return_value = (
             defer.succeed(None)
         )
+
+        retry_timings_res = {
+            "destination": "",
+            "retry_last_ts": 0,
+            "retry_interval": 0,
+        }
         self.mock_persistence.get_destination_retry_timings.return_value = (
-            defer.succeed(DestinationsTable.EntryType("", 0, 0))
+            defer.succeed(retry_timings_res)
         )
         self.mock_persistence.get_auth_chain.return_value = []
         self.clock = MockClock()

View file

@@ -87,6 +87,15 @@ class FederationTestCase(unittest.TestCase):
         self.datastore.get_room.return_value = defer.succeed(True)
         self.auth.check_host_in_room.return_value = defer.succeed(True)
 
+        retry_timings_res = {
+            "destination": "",
+            "retry_last_ts": 0,
+            "retry_interval": 0,
+        }
+        self.datastore.get_destination_retry_timings.return_value = (
+            defer.succeed(retry_timings_res)
+        )
+
         def have_events(event_ids):
             return defer.succeed({})
         self.datastore.have_events.side_effect = have_events
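Each of these test fixtures stubs ``get_destination_retry_timings`` the same way: a ``Mock`` attribute returning an already-fired Deferred, so code written with ``inlineCallbacks`` proceeds synchronously inside the test. A condensed sketch of the idiom (``fetch_thing`` is an invented method name, and ``unittest.mock`` is used here where tests of this era used the standalone ``mock`` package)::

    from twisted.internet import defer
    from unittest.mock import Mock

    datastore = Mock()
    # succeed() returns an already-fired Deferred, so callers that yield
    # on it continue without waiting on the reactor.
    datastore.fetch_thing.return_value = defer.succeed({"retry_interval": 0})

    @defer.inlineCallbacks
    def use_datastore():
        result = yield datastore.fetch_thing()
        assert result["retry_interval"] == 0

    use_datastore()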

View file

@@ -194,8 +194,13 @@ class MockedDatastorePresenceTestCase(PresenceTestCase):
         return datastore
 
     def setUp_datastore_federation_mocks(self, datastore):
+        retry_timings_res = {
+            "destination": "",
+            "retry_last_ts": 0,
+            "retry_interval": 0,
+        }
         datastore.get_destination_retry_timings.return_value = (
-            defer.succeed(DestinationsTable.EntryType("", 0, 0))
+            defer.succeed(retry_timings_res)
         )
 
         def get_received_txn_response(*args):

View file

@@ -96,8 +96,13 @@ class TypingNotificationsTestCase(unittest.TestCase):
         self.event_source = hs.get_event_sources().sources["typing"]
 
         self.datastore = hs.get_datastore()
+        retry_timings_res = {
+            "destination": "",
+            "retry_last_ts": 0,
+            "retry_interval": 0,
+        }
         self.datastore.get_destination_retry_timings.return_value = (
-            defer.succeed(DestinationsTable.EntryType("", 0, 0))
+            defer.succeed(retry_timings_res)
         )
 
         def get_received_txn_response(*args):

View file

@@ -115,12 +115,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
         hs = yield setup_test_homeserver(
             http_client=None,
             replication_layer=Mock(),
-            clock=Mock(spec=[
-                "call_later",
-                "cancel_call_later",
-                "time_msec",
-                "time"
-            ]),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -132,9 +126,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        hs.get_clock().time_msec.return_value = 1000000
-        hs.get_clock().time.return_value = 1000
-
         synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource)
         synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
         synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)

View file

@@ -15,6 +15,7 @@
 from tests import unittest
 from twisted.internet import defer
 
+from tests.utils import setup_test_homeserver
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.server import HomeServer
 from synapse.storage.appservice import (
@@ -33,14 +34,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
-        db_pool = SQLiteMemoryDbPool()
-        yield db_pool.prepare()
-        hs = HomeServer(
-            "test", db_pool=db_pool, clock=MockClock(),
-            config=Mock(
-                app_service_config_files=self.as_yaml_files
-            )
+        config = Mock(
+            app_service_config_files=self.as_yaml_files
         )
+        hs = yield setup_test_homeserver(config=config)
 
         self.as_token = "token1"
         self.as_url = "some_url"
@@ -102,8 +99,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
-        self.db_pool = SQLiteMemoryDbPool()
-        yield self.db_pool.prepare()
+
+        config = Mock(
+            app_service_config_files=self.as_yaml_files
+        )
+        hs = yield setup_test_homeserver(config=config)
+
+        self.db_pool = hs.get_db_pool()
 
         self.as_list = [
             {
@@ -129,11 +131,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         for s in self.as_list:
             yield self._add_service(s["url"], s["token"])
 
-        hs = HomeServer(
-            "test", db_pool=self.db_pool, clock=MockClock(), config=Mock(
-                app_service_config_files=self.as_yaml_files
-            )
-        )
+        self.as_yaml_files = []
+
         self.store = TestTransactionStore(hs)
 
     def _add_service(self, url, as_token):
@@ -302,7 +301,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             (service.id,)
         )
         self.assertEquals(1, len(res))
-        self.assertEquals(str(txn_id), res[0][0])
+        self.assertEquals(txn_id, res[0][0])
 
         res = yield self.db_pool.runQuery(
             "SELECT * FROM application_services_txns WHERE txn_id=?",
@@ -325,7 +324,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             (service.id,)
         )
         self.assertEquals(1, len(res))
-        self.assertEquals(str(txn_id), res[0][0])
+        self.assertEquals(txn_id, res[0][0])
         self.assertEquals(ApplicationServiceState.UP, res[0][1])
 
         res = yield self.db_pool.runQuery(

View file

@@ -24,6 +24,7 @@ from collections import OrderedDict
 
 from synapse.server import HomeServer
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import create_engine
 
 
 class SQLBaseStoreTestCase(unittest.TestCase):
@@ -32,15 +33,26 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def setUp(self):
         self.db_pool = Mock(spec=["runInteraction"])
         self.mock_txn = Mock()
+        self.mock_conn = Mock(spec_set=["cursor"])
+        self.mock_conn.cursor.return_value = self.mock_txn
         # Our fake runInteraction just runs synchronously inline
 
         def runInteraction(func, *args, **kwargs):
             return defer.succeed(func(self.mock_txn, *args, **kwargs))
         self.db_pool.runInteraction = runInteraction
 
+        def runWithConnection(func, *args, **kwargs):
+            return defer.succeed(func(self.mock_conn, *args, **kwargs))
+        self.db_pool.runWithConnection = runWithConnection
+
         config = Mock()
         config.event_cache_size = 1
-        hs = HomeServer("test", db_pool=self.db_pool, config=config)
+        hs = HomeServer(
+            "test",
+            db_pool=self.db_pool,
+            config=config,
+            database_engine=create_engine("sqlite3"),
+        )
 
         self.datastore = SQLBaseStore(hs)
 
@@ -86,8 +98,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.assertEquals("Value", value)
         self.mock_txn.execute.assert_called_with(
-            "SELECT retcol FROM tablename WHERE keycol = ? "
-            "ORDER BY rowid asc",
+            "SELECT retcol FROM tablename WHERE keycol = ?",
             ["TheKey"]
         )
@@ -104,8 +115,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
         self.mock_txn.execute.assert_called_with(
-            "SELECT colA, colB, colC FROM tablename WHERE keycol = ? "
-            "ORDER BY rowid asc",
+            "SELECT colA, colB, colC FROM tablename WHERE keycol = ?",
            ["TheKey"]
         )
@@ -139,8 +149,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
         self.mock_txn.execute.assert_called_with(
-            "SELECT colA FROM tablename WHERE keycol = ? "
-            "ORDER BY rowid asc",
+            "SELECT colA FROM tablename WHERE keycol = ?",
             ["A set"]
         )
@@ -189,8 +198,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.assertEquals({"columname": "Old Value"}, ret)
         self.mock_txn.execute.assert_has_calls([
-            call('SELECT columname FROM tablename WHERE keycol = ? '
-                 'ORDER BY rowid asc',
+            call('SELECT columname FROM tablename WHERE keycol = ?',
                  ['TheKey']),
             call("UPDATE tablename SET columname = ? WHERE keycol = ?",
                  ["New Value", "TheKey"])

View file

@@ -38,31 +38,42 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_register(self):
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
 
-        u = (yield self.store.get_user_by_id(self.user_id))[0]
-        # TODO(paul): Surely this field should be 'user_id', not 'name'
-        #  Additionally surely it shouldn't come in a 1-element list
-        self.assertEquals(self.user_id, u['name'])
-        self.assertEquals(self.pwhash, u['password_hash'])
-
-        self.assertEquals(
-            {"admin": 0,
-             "device_id": None,
-             "name": self.user_id,
-             "token_id": 1},
-            (yield self.store.get_user_by_token(self.tokens[0]))
+        self.assertEquals(
+            # TODO(paul): Surely this field should be 'user_id', not 'name'
+            #  Additionally surely it shouldn't come in a 1-element list
+            {"name": self.user_id, "password_hash": self.pwhash},
+            (yield self.store.get_user_by_id(self.user_id))
         )
+
+        result = yield self.store.get_user_by_token(self.tokens[1])
+
+        self.assertDictContainsSubset(
+            {
+                "admin": 0,
+                "device_id": None,
+                "name": self.user_id,
+            },
+            result
+        )
+        self.assertTrue("token_id" in result)
 
     @defer.inlineCallbacks
     def test_add_tokens(self):
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
         yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
 
-        self.assertEquals(
-            {"admin": 0,
-             "device_id": None,
-             "name": self.user_id,
-             "token_id": 2},
-            (yield self.store.get_user_by_token(self.tokens[1]))
+        result = yield self.store.get_user_by_token(self.tokens[1])
+
+        self.assertDictContainsSubset(
+            {
+                "admin": 0,
+                "device_id": None,
+                "name": self.user_id,
+            },
+            result
        )
+        self.assertTrue("token_id" in result)
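Switching from a whole-dict ``assertEquals`` to ``assertDictContainsSubset`` plus a presence check stops these tests pinning the autogenerated ``token_id``, whose value need not match between SQLite and Postgres sequences. The equivalent check spelled out in plain assertions (data invented; ``assertDictContainsSubset`` itself is deprecated in newer Pythons, so the subset test is written by hand here)::

    result = {"admin": 0, "device_id": None, "name": "@user:test", "token_id": 7}
    expected = {"admin": 0, "device_id": None, "name": "@user:test"}

    # Every expected pair must be present, but extra keys such as the
    # autoincremented token_id are only checked for existence.
    assert all(result.get(k) == v for k, v in expected.items())
    assert "token_id" in result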

View file

@@ -119,7 +119,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
 
         self.assertEquals(
-            ["test"],
+            {"test"},
             (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
         )
 
@@ -127,7 +127,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
 
         self.assertEquals(
-            ["test"],
+            {"test"},
             (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
         )
 
@@ -136,9 +136,9 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         self.assertEquals(
             {"test", "elsewhere"},
-            set((yield
+            (yield
                 self.store.get_joined_hosts_for_room(self.room.to_string())
-            ))
+            )
         )
 
         # Should still have both hosts
@@ -146,15 +146,15 @@ class RoomMemberStoreTestCase(unittest.TestCase):
         self.assertEquals(
             {"test", "elsewhere"},
-            set((yield
+            (yield
                 self.store.get_joined_hosts_for_room(self.room.to_string())
-            ))
+            )
         )
 
         # Should have only one host after other leaves
         yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
 
         self.assertEquals(
-            ["test"],
+            {"test"},
             (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
         )
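These assertions now compare sets rather than lists, which matters once Postgres is in play: without an explicit ORDER BY, row order is not guaranteed, so a list comparison that happened to pass on SQLite could become flaky. A tiny illustration of the safer assertion style::

    rows = ["elsewhere", "test"]  # whichever order the database returned

    # A list comparison encodes an ordering the query never promised:
    assert rows != ["test", "elsewhere"]

    # A set comparison checks membership only, so it is stable across engines:
    assert set(rows) == {"test", "elsewhere"}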

View file

@@ -17,6 +17,7 @@ from synapse.http.server import HttpServer
 from synapse.api.errors import cs_error, CodeMessageException, StoreError
 from synapse.api.constants import EventTypes
 from synapse.storage import prepare_database
+from synapse.storage.engines import create_engine
 
 from synapse.server import HomeServer
 from synapse.util.logcontext import LoggingContext
@@ -44,18 +45,23 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.event_cache_size = 1
         config.disable_registration = False
 
+    if "clock" not in kargs:
+        kargs["clock"] = MockClock()
+
     if datastore is None:
         db_pool = SQLiteMemoryDbPool()
         yield db_pool.prepare()
         hs = HomeServer(
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
+            database_engine=create_engine("sqlite3"),
             **kargs
         )
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
+            database_engine=create_engine("sqlite3"),
             **kargs
         )
 
@@ -227,7 +233,10 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
         )
 
     def prepare(self):
-        return self.runWithConnection(prepare_database)
+        engine = create_engine("sqlite3")
+        return self.runWithConnection(
+            lambda conn: prepare_database(conn, engine)
+        )
 
 
 class MemoryDataStore(object):
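After this change ``prepare_database`` takes the database engine as a second argument, which is why ``SQLiteMemoryDbPool.prepare`` now wraps it in a lambda: schema setup can branch on SQLite versus Postgres. A rough sketch of driving it directly (the raw in-memory SQLite connection is this example's assumption, not part of the patch)::

    import sqlite3

    from synapse.storage import prepare_database
    from synapse.storage.engines import create_engine

    engine = create_engine("sqlite3")
    conn = sqlite3.connect(":memory:")

    # Build the full schema on a fresh connection; the engine object
    # supplies any engine-specific SQL during preparation.
    prepare_database(conn, engine)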