mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-15 22:42:23 +01:00
Update type hints for Cursor to match PEP 249. (#9299)
This commit is contained in:
parent
5a9cdaa6e9
commit
d882fbca38
5 changed files with 48 additions and 18 deletions
1
changelog.d/9299.misc
Normal file
1
changelog.d/9299.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update the `Cursor` type hints to better match PEP 249.
|
|
@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
|
||||||
def commit(self) -> None:
|
def commit(self) -> None:
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def rollback(self, *args, **kwargs) -> None:
|
def rollback(self) -> None:
|
||||||
self.conn.rollback(*args, **kwargs)
|
self.conn.rollback()
|
||||||
|
|
||||||
def __enter__(self) -> "Connection":
|
def __enter__(self) -> "Connection":
|
||||||
self.conn.__enter__()
|
self.conn.__enter__()
|
||||||
|
@ -244,12 +244,15 @@ class LoggingTransaction:
|
||||||
assert self.exception_callbacks is not None
|
assert self.exception_callbacks is not None
|
||||||
self.exception_callbacks.append((callback, args, kwargs))
|
self.exception_callbacks.append((callback, args, kwargs))
|
||||||
|
|
||||||
|
def fetchone(self) -> Optional[Tuple]:
|
||||||
|
return self.txn.fetchone()
|
||||||
|
|
||||||
|
def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
|
||||||
|
return self.txn.fetchmany(size=size)
|
||||||
|
|
||||||
def fetchall(self) -> List[Tuple]:
|
def fetchall(self) -> List[Tuple]:
|
||||||
return self.txn.fetchall()
|
return self.txn.fetchall()
|
||||||
|
|
||||||
def fetchone(self) -> Tuple:
|
|
||||||
return self.txn.fetchone()
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Tuple]:
|
def __iter__(self) -> Iterator[Tuple]:
|
||||||
return self.txn.__iter__()
|
return self.txn.__iter__()
|
||||||
|
|
||||||
|
@ -754,6 +757,7 @@ class DatabasePool:
|
||||||
Returns:
|
Returns:
|
||||||
A list of dicts where the key is the column header.
|
A list of dicts where the key is the column header.
|
||||||
"""
|
"""
|
||||||
|
assert cursor.description is not None, "cursor.description was None"
|
||||||
col_headers = [intern(str(column[0])) for column in cursor.description]
|
col_headers = [intern(str(column[0])) for column in cursor.description]
|
||||||
results = [dict(zip(col_headers, row)) for row in cursor]
|
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -619,9 +619,9 @@ def _get_or_create_schema_state(
|
||||||
|
|
||||||
txn.execute("SELECT version, upgraded FROM schema_version")
|
txn.execute("SELECT version, upgraded FROM schema_version")
|
||||||
row = txn.fetchone()
|
row = txn.fetchone()
|
||||||
current_version = int(row[0]) if row else None
|
|
||||||
|
|
||||||
if current_version:
|
if row is not None:
|
||||||
|
current_version = int(row[0])
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
|
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
|
||||||
(current_version,),
|
(current_version,),
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Iterable, Iterator, List, Optional, Tuple
|
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
@ -20,23 +20,44 @@ from typing_extensions import Protocol
|
||||||
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
|
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_Parameters = Union[Sequence[Any], Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class Cursor(Protocol):
|
class Cursor(Protocol):
|
||||||
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
|
def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
|
def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
def fetchone(self) -> Optional[Tuple]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def fetchall(self) -> List[Tuple]:
|
def fetchall(self) -> List[Tuple]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def fetchone(self) -> Tuple:
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> Any:
|
def description(
|
||||||
return None
|
self,
|
||||||
|
) -> Optional[
|
||||||
|
Sequence[
|
||||||
|
# Note that this is an approximate typing based on sqlite3 and other
|
||||||
|
# drivers, and may not be entirely accurate.
|
||||||
|
Tuple[
|
||||||
|
str,
|
||||||
|
Optional[Any],
|
||||||
|
Optional[int],
|
||||||
|
Optional[int],
|
||||||
|
Optional[int],
|
||||||
|
Optional[int],
|
||||||
|
Optional[int],
|
||||||
|
]
|
||||||
|
]
|
||||||
|
]:
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rowcount(self) -> int:
|
def rowcount(self) -> int:
|
||||||
|
@ -59,7 +80,7 @@ class Connection(Protocol):
|
||||||
def commit(self) -> None:
|
def commit(self) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def rollback(self, *args, **kwargs) -> None:
|
def rollback(self) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __enter__(self) -> "Connection":
|
def __enter__(self) -> "Connection":
|
||||||
|
|
|
@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||||
|
|
||||||
def get_next_id_txn(self, txn: Cursor) -> int:
|
def get_next_id_txn(self, txn: Cursor) -> int:
|
||||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||||
return txn.fetchone()[0]
|
fetch_res = txn.fetchone()
|
||||||
|
assert fetch_res is not None
|
||||||
|
return fetch_res[0]
|
||||||
|
|
||||||
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
|
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
|
||||||
)
|
)
|
||||||
last_value, is_called = txn.fetchone()
|
fetch_res = txn.fetchone()
|
||||||
|
assert fetch_res is not None
|
||||||
|
last_value, is_called = fetch_res
|
||||||
|
|
||||||
# If we have an associated stream check the stream_positions table.
|
# If we have an associated stream check the stream_positions table.
|
||||||
max_in_stream_positions = None
|
max_in_stream_positions = None
|
||||||
|
|
Loading…
Reference in a new issue