0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-15 22:42:23 +01:00

Update type hints for Cursor to match PEP 249. (#9299)

This commit is contained in:
Jonathan de Jong 2021-02-05 21:39:19 +01:00 committed by GitHub
parent 5a9cdaa6e9
commit d882fbca38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 48 additions and 18 deletions

1
changelog.d/9299.misc Normal file
View file

@ -0,0 +1 @@
Update the `Cursor` type hints to better match PEP 249.

View file

@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None: def commit(self) -> None:
self.conn.commit() self.conn.commit()
def rollback(self, *args, **kwargs) -> None: def rollback(self) -> None:
self.conn.rollback(*args, **kwargs) self.conn.rollback()
def __enter__(self) -> "Connection": def __enter__(self) -> "Connection":
self.conn.__enter__() self.conn.__enter__()
@ -244,12 +244,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs)) self.exception_callbacks.append((callback, args, kwargs))
def fetchone(self) -> Optional[Tuple]:
return self.txn.fetchone()
def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
return self.txn.fetchmany(size=size)
def fetchall(self) -> List[Tuple]: def fetchall(self) -> List[Tuple]:
return self.txn.fetchall() return self.txn.fetchall()
def fetchone(self) -> Tuple:
return self.txn.fetchone()
def __iter__(self) -> Iterator[Tuple]: def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__() return self.txn.__iter__()
@ -754,6 +757,7 @@ class DatabasePool:
Returns: Returns:
A list of dicts where the key is the column header. A list of dicts where the key is the column header.
""" """
assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description] col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor] results = [dict(zip(col_headers, row)) for row in cursor]
return results return results

View file

@ -619,9 +619,9 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version") txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone() row = txn.fetchone()
current_version = int(row[0]) if row else None
if current_version: if row is not None:
current_version = int(row[0])
txn.execute( txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?", "SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,), (current_version,),

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Iterable, Iterator, List, Optional, Tuple from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol from typing_extensions import Protocol
@ -20,23 +20,44 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
""" """
_Parameters = Union[Sequence[Any], Mapping[str, Any]]
class Cursor(Protocol): class Cursor(Protocol):
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
... ...
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...
def fetchone(self) -> Optional[Tuple]:
...
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
... ...
def fetchall(self) -> List[Tuple]: def fetchall(self) -> List[Tuple]:
... ...
def fetchone(self) -> Tuple:
...
@property @property
def description(self) -> Any: def description(
return None self,
) -> Optional[
Sequence[
# Note that this is an approximate typing based on sqlite3 and other
# drivers, and may not be entirely accurate.
Tuple[
str,
Optional[Any],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
Optional[int],
]
]
]:
...
@property @property
def rowcount(self) -> int: def rowcount(self) -> int:
@ -59,7 +80,7 @@ class Connection(Protocol):
def commit(self) -> None: def commit(self) -> None:
... ...
def rollback(self, *args, **kwargs) -> None: def rollback(self) -> None:
... ...
def __enter__(self) -> "Connection": def __enter__(self) -> "Connection":

View file

@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
def get_next_id_txn(self, txn: Cursor) -> int: def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,)) txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0] fetch_res = txn.fetchone()
assert fetch_res is not None
return fetch_res[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute( txn.execute(
@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute( txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
) )
last_value, is_called = txn.fetchone() fetch_res = txn.fetchone()
assert fetch_res is not None
last_value, is_called = fetch_res
# If we have an associated stream check the stream_positions table. # If we have an associated stream check the stream_positions table.
max_in_stream_positions = None max_in_stream_positions = None