Add strict typing for recorder util (#68681)

This commit is contained in:
J. Nick Koston 2022-03-25 12:03:46 -10:00 committed by GitHub
parent 4dc8aff3d5
commit 225f7a989b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 22 deletions

View File

@ -168,6 +168,7 @@ homeassistant.components.recorder.history
homeassistant.components.recorder.purge
homeassistant.components.recorder.repack
homeassistant.components.recorder.statistics
homeassistant.components.recorder.util
homeassistant.components.remote.*
homeassistant.components.renault.*
homeassistant.components.ridwell.*

View File

@ -3,12 +3,12 @@ from __future__ import annotations
from collections.abc import Callable, Generator
from contextlib import contextmanager
from datetime import timedelta
from datetime import datetime, timedelta
import functools
import logging
import os
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from awesomeversion import (
AwesomeVersion,
@ -16,6 +16,7 @@ from awesomeversion import (
AwesomeVersionStrategy,
)
from sqlalchemy import text
from sqlalchemy.engine.cursor import CursorFetchStrategy
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
@ -95,7 +96,7 @@ def session_scope(
session.close()
def commit(session, work):
def commit(session: Session, work: Any) -> bool:
"""Commit & retry work: Either a model or in a function."""
for _ in range(0, RETRIES):
try:
@ -175,12 +176,12 @@ def validate_or_move_away_sqlite_database(dburl: str) -> bool:
return True
def dburl_to_path(dburl):
def dburl_to_path(dburl: str) -> str:
"""Convert the db url into a filesystem path."""
return dburl[len(SQLITE_URL_PREFIX) :]
def last_run_was_recently_clean(cursor):
def last_run_was_recently_clean(cursor: CursorFetchStrategy) -> bool:
"""Verify the last recorder run was recently clean."""
cursor.execute("SELECT end FROM recorder_runs ORDER BY start DESC LIMIT 1;")
@ -190,6 +191,7 @@ def last_run_was_recently_clean(cursor):
return False
last_run_end_time = process_timestamp(dt_util.parse_datetime(end_time[0]))
assert last_run_end_time is not None
now = dt_util.utcnow()
_LOGGER.debug("The last run ended at: %s (now: %s)", last_run_end_time, now)
@ -200,7 +202,7 @@ def last_run_was_recently_clean(cursor):
return True
def basic_sanity_check(cursor):
def basic_sanity_check(cursor: CursorFetchStrategy) -> bool:
"""Check tables to make sure select does not fail."""
for table in ALL_TABLES:
@ -235,7 +237,7 @@ def validate_sqlite_database(dbpath: str) -> bool:
return True
def run_checks_on_open_db(dbpath, cursor):
def run_checks_on_open_db(dbpath: str, cursor: CursorFetchStrategy) -> None:
"""Run checks that will generate a sqlite3 exception if there is corruption."""
sanity_check_passed = basic_sanity_check(cursor)
last_run_was_clean = last_run_was_recently_clean(cursor)
@ -278,14 +280,14 @@ def move_away_broken_database(dbfile: str) -> None:
os.rename(path, f"{path}{corrupt_postfix}")
def execute_on_connection(dbapi_connection, statement):
def execute_on_connection(dbapi_connection: Any, statement: str) -> None:
"""Execute a single statement with a dbapi connection."""
cursor = dbapi_connection.cursor()
cursor.execute(statement)
cursor.close()
def query_on_connection(dbapi_connection, statement):
def query_on_connection(dbapi_connection: Any, statement: str) -> Any:
"""Execute a single statement with a dbapi connection and return the result."""
cursor = dbapi_connection.cursor()
cursor.execute(statement)
@ -294,30 +296,34 @@ def query_on_connection(dbapi_connection, statement):
return result
def _warn_unsupported_dialect(dialect):
def _warn_unsupported_dialect(dialect_name: str) -> None:
"""Warn about unsupported database version."""
_LOGGER.warning(
"Database %s is not supported; Home Assistant supports %s. "
"Starting with Home Assistant 2022.2 this will prevent the recorder from "
"starting. Please migrate your database to a supported software before then",
dialect,
dialect_name,
"MariaDB ≥ 10.3, MySQL ≥ 8.0, PostgreSQL ≥ 12, SQLite ≥ 3.31.0",
)
def _warn_unsupported_version(server_version, dialect, minimum_version):
def _warn_unsupported_version(
server_version: str, dialect_name: str, minimum_version: str
) -> None:
"""Warn about unsupported database version."""
_LOGGER.warning(
"Version %s of %s is not supported; minimum supported version is %s. "
"Starting with Home Assistant 2022.2 this will prevent the recorder from "
"starting. Please upgrade your database software before then",
server_version,
dialect,
dialect_name,
minimum_version,
)
def _extract_version_from_server_response(server_response):
def _extract_version_from_server_response(
server_response: str,
) -> AwesomeVersion | None:
"""Attempt to extract version from server response."""
try:
return AwesomeVersion(
@ -330,8 +336,11 @@ def _extract_version_from_server_response(server_response):
def setup_connection_for_dialect(
instance, dialect_name, dbapi_connection, first_connection
):
instance: Recorder,
dialect_name: str,
dbapi_connection: Any,
first_connection: bool,
) -> None:
"""Execute statements needed for dialect connection."""
# Returns False if the the connection needs to be setup
# on the next connection, returns True if the connection
@ -406,7 +415,7 @@ def setup_connection_for_dialect(
_warn_unsupported_dialect(dialect_name)
def end_incomplete_runs(session, start_time):
def end_incomplete_runs(session: Session, start_time: datetime) -> None:
"""End any incomplete recorder runs."""
for run in session.query(RecorderRuns).filter_by(end=None):
run.closed_incorrect = True
@ -423,9 +432,9 @@ def retryable_database_job(description: str) -> Callable:
The job should return True if it finished, and False if it needs to be rescheduled.
"""
def decorator(job: Callable) -> Callable:
def decorator(job: Callable[[Any], bool]) -> Callable:
@functools.wraps(job)
def wrapper(instance: Recorder, *args, **kwargs):
def wrapper(instance: Recorder, *args: Any, **kwargs: Any) -> bool:
try:
return job(instance, *args, **kwargs)
except OperationalError as err:
@ -451,7 +460,7 @@ def retryable_database_job(description: str) -> Callable:
return decorator
def perodic_db_cleanups(instance: Recorder):
def perodic_db_cleanups(instance: Recorder) -> None:
"""Run any database cleanups that need to happen perodiclly.
These cleanups will happen nightly or after any purge.
@ -465,7 +474,7 @@ def perodic_db_cleanups(instance: Recorder):
@contextmanager
def write_lock_db_sqlite(instance: Recorder):
def write_lock_db_sqlite(instance: Recorder) -> Generator[None, None, None]:
"""Lock database for writes."""
assert instance.engine is not None
with instance.engine.connect() as connection:
@ -490,4 +499,5 @@ def async_migration_in_progress(hass: HomeAssistant) -> bool:
"""
if DATA_INSTANCE not in hass.data:
return False
return hass.data[DATA_INSTANCE].migration_in_progress
instance: Recorder = hass.data[DATA_INSTANCE]
return instance.migration_in_progress

View File

@ -1650,6 +1650,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.util]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.remote.*]
check_untyped_defs = true
disallow_incomplete_defs = true