mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-11 20:42:23 +01:00
Make synapse._scripts
pass typechecks (#12421)
This commit is contained in:
parent
dd5cc37aa4
commit
0cd182f296
6 changed files with 50 additions and 43 deletions
1
changelog.d/12421.misc
Normal file
1
changelog.d/12421.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Make `synapse._scripts` pass type checks.
|
5
mypy.ini
5
mypy.ini
|
@ -28,11 +28,6 @@ exclude = (?x)
|
||||||
|scripts-dev/federation_client.py
|
|scripts-dev/federation_client.py
|
||||||
|scripts-dev/release.py
|
|scripts-dev/release.py
|
||||||
|
|
||||||
|synapse/_scripts/export_signing_key.py
|
|
||||||
|synapse/_scripts/move_remote_media_to_new_store.py
|
|
||||||
|synapse/_scripts/synapse_port_db.py
|
|
||||||
|synapse/_scripts/update_synapse_database.py
|
|
||||||
|
|
||||||
|synapse/storage/databases/__init__.py
|
|synapse/storage/databases/__init__.py
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|
|
|
@ -17,8 +17,8 @@ import sys
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import nacl.signing
|
|
||||||
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
|
from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
|
||||||
|
from signedjson.types import VerifyKey
|
||||||
|
|
||||||
|
|
||||||
def exit(status: int = 0, message: Optional[str] = None):
|
def exit(status: int = 0, message: Optional[str] = None):
|
||||||
|
@ -27,7 +27,7 @@ def exit(status: int = 0, message: Optional[str] = None):
|
||||||
sys.exit(status)
|
sys.exit(status)
|
||||||
|
|
||||||
|
|
||||||
def format_plain(public_key: nacl.signing.VerifyKey):
|
def format_plain(public_key: VerifyKey):
|
||||||
print(
|
print(
|
||||||
"%s:%s %s"
|
"%s:%s %s"
|
||||||
% (
|
% (
|
||||||
|
@ -38,7 +38,7 @@ def format_plain(public_key: nacl.signing.VerifyKey):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
|
def format_for_config(public_key: VerifyKey, expiry_ts: int):
|
||||||
print(
|
print(
|
||||||
' "%s:%s": { key: "%s", expired_ts: %i }'
|
' "%s:%s": { key: "%s", expired_ts: %i }'
|
||||||
% (
|
% (
|
||||||
|
|
|
@ -109,10 +109,9 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("dest_repo", help="Path to source content repo")
|
parser.add_argument("dest_repo", help="Path to source content repo")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging_config = {
|
logging.basicConfig(
|
||||||
"level": logging.DEBUG if args.v else logging.INFO,
|
level=logging.DEBUG if args.v else logging.INFO,
|
||||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||||
}
|
)
|
||||||
logging.basicConfig(**logging_config)
|
|
||||||
|
|
||||||
main(args.src_repo, args.dest_repo)
|
main(args.src_repo, args.dest_repo)
|
||||||
|
|
|
@ -21,12 +21,13 @@ import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, Iterable, Optional, Set
|
from types import TracebackType
|
||||||
|
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from matrix_common.versionstring import get_distribution_version_string
|
from matrix_common.versionstring import get_distribution_version_string
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor as reactor_
|
||||||
|
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
@ -66,8 +67,12 @@ from synapse.storage.databases.main.user_directory import (
|
||||||
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
|
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
# Cast safety: Twisted does some naughty magic which replaces the
|
||||||
|
# twisted.internet.reactor module with a Reactor instance at runtime.
|
||||||
|
reactor = cast(ISynapseReactor, reactor_)
|
||||||
logger = logging.getLogger("synapse_port_db")
|
logger = logging.getLogger("synapse_port_db")
|
||||||
|
|
||||||
|
|
||||||
|
@ -159,12 +164,14 @@ IGNORED_TABLES = {
|
||||||
|
|
||||||
# Error returned by the run function. Used at the top-level part of the script to
|
# Error returned by the run function. Used at the top-level part of the script to
|
||||||
# handle errors and return codes.
|
# handle errors and return codes.
|
||||||
end_error = None # type: Optional[str]
|
end_error: Optional[str] = None
|
||||||
# The exec_info for the error, if any. If error is defined but not exec_info the script
|
# The exec_info for the error, if any. If error is defined but not exec_info the script
|
||||||
# will show only the error message without the stacktrace, if exec_info is defined but
|
# will show only the error message without the stacktrace, if exec_info is defined but
|
||||||
# not the error then the script will show nothing outside of what's printed in the run
|
# not the error then the script will show nothing outside of what's printed in the run
|
||||||
# function. If both are defined, the script will print both the error and the stacktrace.
|
# function. If both are defined, the script will print both the error and the stacktrace.
|
||||||
end_error_exec_info = None
|
end_error_exec_info: Optional[
|
||||||
|
Tuple[Type[BaseException], BaseException, TracebackType]
|
||||||
|
] = None
|
||||||
|
|
||||||
|
|
||||||
class Store(
|
class Store(
|
||||||
|
@ -236,9 +243,12 @@ class MockHomeserver:
|
||||||
return "master"
|
return "master"
|
||||||
|
|
||||||
|
|
||||||
class Porter(object):
|
class Porter:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, sqlite_config, progress, batch_size, hs_config):
|
||||||
self.__dict__.update(kwargs)
|
self.sqlite_config = sqlite_config
|
||||||
|
self.progress = progress
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.hs_config = hs_config
|
||||||
|
|
||||||
async def setup_table(self, table):
|
async def setup_table(self, table):
|
||||||
if table in APPEND_ONLY_TABLES:
|
if table in APPEND_ONLY_TABLES:
|
||||||
|
@ -323,7 +333,7 @@ class Porter(object):
|
||||||
"""
|
"""
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
|
|
||||||
results = {}
|
results: Dict[str, Set[str]] = {}
|
||||||
for table, foreign_table in txn:
|
for table, foreign_table in txn:
|
||||||
results.setdefault(table, set()).add(foreign_table)
|
results.setdefault(table, set()).add(foreign_table)
|
||||||
return results
|
return results
|
||||||
|
@ -540,7 +550,8 @@ class Porter(object):
|
||||||
db_conn, allow_outdated_version=allow_outdated_version
|
db_conn, allow_outdated_version=allow_outdated_version
|
||||||
)
|
)
|
||||||
prepare_database(db_conn, engine, config=self.hs_config)
|
prepare_database(db_conn, engine, config=self.hs_config)
|
||||||
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
|
# Type safety: ignore that we're using Mock homeservers here.
|
||||||
|
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
|
||||||
db_conn.commit()
|
db_conn.commit()
|
||||||
|
|
||||||
return store
|
return store
|
||||||
|
@ -724,7 +735,9 @@ class Porter(object):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
global end_error_exec_info
|
global end_error_exec_info
|
||||||
end_error = str(e)
|
end_error = str(e)
|
||||||
end_error_exec_info = sys.exc_info()
|
# Type safety: we're in an exception handler, so the exc_info() tuple
|
||||||
|
# will not be (None, None, None).
|
||||||
|
end_error_exec_info = sys.exc_info() # type: ignore[assignment]
|
||||||
logger.exception("")
|
logger.exception("")
|
||||||
finally:
|
finally:
|
||||||
reactor.stop()
|
reactor.stop()
|
||||||
|
@ -1023,7 +1036,7 @@ class CursesProgress(Progress):
|
||||||
curses.init_pair(1, curses.COLOR_RED, -1)
|
curses.init_pair(1, curses.COLOR_RED, -1)
|
||||||
curses.init_pair(2, curses.COLOR_GREEN, -1)
|
curses.init_pair(2, curses.COLOR_GREEN, -1)
|
||||||
|
|
||||||
self.last_update = 0
|
self.last_update = 0.0
|
||||||
|
|
||||||
self.finished = False
|
self.finished = False
|
||||||
|
|
||||||
|
@ -1082,8 +1095,7 @@ class CursesProgress(Progress):
|
||||||
left_margin = 5
|
left_margin = 5
|
||||||
middle_space = 1
|
middle_space = 1
|
||||||
|
|
||||||
items = self.tables.items()
|
items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0]))
|
||||||
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
|
|
||||||
|
|
||||||
for i, (table, data) in enumerate(items):
|
for i, (table, data) in enumerate(items):
|
||||||
if i + 2 >= rows:
|
if i + 2 >= rows:
|
||||||
|
@ -1179,15 +1191,11 @@ def main():
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging_config = {
|
logging.basicConfig(
|
||||||
"level": logging.DEBUG if args.v else logging.INFO,
|
level=logging.DEBUG if args.v else logging.INFO,
|
||||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||||
}
|
filename="port-synapse.log" if args.curses else None,
|
||||||
|
)
|
||||||
if args.curses:
|
|
||||||
logging_config["filename"] = "port-synapse.log"
|
|
||||||
|
|
||||||
logging.basicConfig(**logging_config)
|
|
||||||
|
|
||||||
sqlite_config = {
|
sqlite_config = {
|
||||||
"name": "sqlite3",
|
"name": "sqlite3",
|
||||||
|
@ -1218,6 +1226,7 @@ def main():
|
||||||
config.parse_config_dict(hs_config, "", "")
|
config.parse_config_dict(hs_config, "", "")
|
||||||
|
|
||||||
def start(stdscr=None):
|
def start(stdscr=None):
|
||||||
|
progress: Progress
|
||||||
if stdscr:
|
if stdscr:
|
||||||
progress = CursesProgress(stdscr)
|
progress = CursesProgress(stdscr)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -16,22 +16,27 @@
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from matrix_common.versionstring import get_distribution_version_string
|
from matrix_common.versionstring import get_distribution_version_string
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor as reactor_
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
|
from synapse.types import ISynapseReactor
|
||||||
|
|
||||||
|
# Cast safety: Twisted does some naughty magic which replaces the
|
||||||
|
# twisted.internet.reactor module with a Reactor instance at runtime.
|
||||||
|
reactor = cast(ISynapseReactor, reactor_)
|
||||||
logger = logging.getLogger("update_database")
|
logger = logging.getLogger("update_database")
|
||||||
|
|
||||||
|
|
||||||
class MockHomeserver(HomeServer):
|
class MockHomeserver(HomeServer):
|
||||||
DATASTORE_CLASS = DataStore
|
DATASTORE_CLASS = DataStore # type: ignore [assignment]
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super(MockHomeserver, self).__init__(
|
super(MockHomeserver, self).__init__(
|
||||||
|
@ -85,12 +90,10 @@ def main():
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging_config = {
|
logging.basicConfig(
|
||||||
"level": logging.DEBUG if args.v else logging.INFO,
|
level=logging.DEBUG if args.v else logging.INFO,
|
||||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
|
||||||
}
|
)
|
||||||
|
|
||||||
logging.basicConfig(**logging_config)
|
|
||||||
|
|
||||||
# Load, process and sanity-check the config.
|
# Load, process and sanity-check the config.
|
||||||
hs_config = yaml.safe_load(args.database_config)
|
hs_config = yaml.safe_load(args.database_config)
|
||||||
|
|
Loading…
Reference in a new issue