Replace asyncio.iscoroutinefunction (#148738)

This commit is contained in:
Marc Mueller 2025-07-14 21:24:32 +02:00 committed by GitHub
parent 9e3a78b7ef
commit 80eb4fb2f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 22 additions and 18 deletions

View File

@ -2,9 +2,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from functools import wraps from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Final, overload from typing import TYPE_CHECKING, Any, Final, overload
import knx_frontend as knx_panel import knx_frontend as knx_panel
@ -116,7 +116,7 @@ def provide_knx(
"KNX integration not loaded.", "KNX integration not loaded.",
) )
if asyncio.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@wraps(func) @wraps(func)
async def with_knx( async def with_knx(

View File

@ -384,7 +384,7 @@ def get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType:
while isinstance(check_target, functools.partial): while isinstance(check_target, functools.partial):
check_target = check_target.func check_target = check_target.func
if asyncio.iscoroutinefunction(check_target): if inspect.iscoroutinefunction(check_target):
return HassJobType.Coroutinefunction return HassJobType.Coroutinefunction
if is_callback(check_target): if is_callback(check_target):
return HassJobType.Callback return HassJobType.Callback

View File

@ -3,12 +3,12 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Container, Coroutine, Generator, Iterable from collections.abc import Callable, Container, Coroutine, Generator, Iterable
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, time as dt_time, timedelta from datetime import datetime, time as dt_time, timedelta
import functools as ft import functools as ft
import inspect
import logging import logging
import re import re
import sys import sys
@ -359,7 +359,7 @@ async def async_from_config(
while isinstance(check_factory, ft.partial): while isinstance(check_factory, ft.partial):
check_factory = check_factory.func check_factory = check_factory.func
if asyncio.iscoroutinefunction(check_factory): if inspect.iscoroutinefunction(check_factory):
return cast(ConditionCheckerType, await factory(hass, config)) return cast(ConditionCheckerType, await factory(hass, config))
return cast(ConditionCheckerType, factory(config)) return cast(ConditionCheckerType, factory(config))

View File

@ -2,11 +2,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
import enum import enum
import functools import functools
import inspect
import linecache import linecache
import logging import logging
import sys import sys
@ -397,7 +397,7 @@ def _report_usage_no_integration(
def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT: def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT:
"""Mock a function to warn when it was about to be used.""" """Mock a function to warn when it was about to be used."""
if asyncio.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
@functools.wraps(func) @functools.wraps(func)
async def report_use(*args: Any, **kwargs: Any) -> None: async def report_use(*args: Any, **kwargs: Any) -> None:

View File

@ -2,10 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextvars import ContextVar from contextvars import ContextVar
from http import HTTPStatus from http import HTTPStatus
import inspect
import logging import logging
from typing import Any, Final from typing import Any, Final
@ -45,7 +45,7 @@ def request_handler_factory(
hass: HomeAssistant, view: HomeAssistantView, handler: Callable hass: HomeAssistant, view: HomeAssistantView, handler: Callable
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
"""Wrap the handler classes.""" """Wrap the handler classes."""
is_coroutinefunction = asyncio.iscoroutinefunction(handler) is_coroutinefunction = inspect.iscoroutinefunction(handler)
assert is_coroutinefunction or is_callback(handler), ( assert is_coroutinefunction or is_callback(handler), (
"Handler should be a coroutine or a callback." "Handler should be a coroutine or a callback."
) )

View File

@ -7,6 +7,7 @@ from collections.abc import Callable, Coroutine, Iterable
import dataclasses import dataclasses
from enum import Enum from enum import Enum
from functools import cache, partial from functools import cache, partial
import inspect
import logging import logging
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, cast, override from typing import TYPE_CHECKING, Any, TypedDict, cast, override
@ -997,7 +998,7 @@ def verify_domain_control(
service_handler: Callable[[ServiceCall], Any], service_handler: Callable[[ServiceCall], Any],
) -> Callable[[ServiceCall], Any]: ) -> Callable[[ServiceCall], Any]:
"""Decorate.""" """Decorate."""
if not asyncio.iscoroutinefunction(service_handler): if not inspect.iscoroutinefunction(service_handler):
raise HomeAssistantError("Can only decorate async functions.") raise HomeAssistantError("Can only decorate async functions.")
async def check_permissions(call: ServiceCall) -> Any: async def check_permissions(call: ServiceCall) -> Any:

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
import functools import functools
import inspect
from typing import Any, Literal, assert_type, cast, overload from typing import Any, Literal, assert_type, cast, overload
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -47,7 +48,7 @@ def singleton[_S, _T, _U](
def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]: def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]:
"""Wrap a function with caching logic.""" """Wrap a function with caching logic."""
if not asyncio.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
@bind_hass @bind_hass

View File

@ -8,6 +8,7 @@ from collections import defaultdict
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
import functools import functools
import inspect
import logging import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
@ -407,7 +408,7 @@ def _trigger_action_wrapper(
check_func = check_func.func check_func = check_func.func
wrapper_func: Callable[..., Any] | Callable[..., Coroutine[Any, Any, Any]] wrapper_func: Callable[..., Any] | Callable[..., Coroutine[Any, Any, Any]]
if asyncio.iscoroutinefunction(check_func): if inspect.iscoroutinefunction(check_func):
async_action = cast(Callable[..., Coroutine[Any, Any, Any]], action) async_action = cast(Callable[..., Coroutine[Any, Any, Any]], action)
@functools.wraps(async_action) @functools.wraps(async_action)

View File

@ -2,10 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable, KeysView, Mapping from collections.abc import Callable, Coroutine, Iterable, KeysView, Mapping
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
import inspect
import random import random
import re import re
import string import string
@ -125,7 +125,7 @@ class Throttle:
def __call__(self, method: Callable) -> Callable: def __call__(self, method: Callable) -> Callable:
"""Caller for the throttle.""" """Caller for the throttle."""
# Make sure we return a coroutine if the method is async. # Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method): if inspect.iscoroutinefunction(method):
async def throttled_value() -> None: async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled.""" """Stand-in function for when real func is being throttled."""

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import inspect
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -191,7 +191,7 @@ async def trigger_subscription_callback(
object_id=object_id, object_id=object_id,
data=data, data=data,
) )
if asyncio.iscoroutinefunction(cb_func): if inspect.iscoroutinefunction(cb_func):
await cb_func(event) await cb_func(event)
else: else:
cb_func(event) cb_func(event)

View File

@ -2,6 +2,7 @@
import asyncio import asyncio
from functools import partial from functools import partial
import inspect
import logging import logging
import queue import queue
from unittest.mock import patch from unittest.mock import patch
@ -102,7 +103,7 @@ def test_catch_log_exception() -> None:
async def async_meth(): async def async_meth():
pass pass
assert asyncio.iscoroutinefunction( assert inspect.iscoroutinefunction(
logging_util.catch_log_exception(partial(async_meth), lambda: None) logging_util.catch_log_exception(partial(async_meth), lambda: None)
) )
@ -120,7 +121,7 @@ def test_catch_log_exception() -> None:
wrapped = logging_util.catch_log_exception(partial(sync_meth), lambda: None) wrapped = logging_util.catch_log_exception(partial(sync_meth), lambda: None)
assert not is_callback(wrapped) assert not is_callback(wrapped)
assert not asyncio.iscoroutinefunction(wrapped) assert not inspect.iscoroutinefunction(wrapped)
@pytest.mark.no_fail_on_log_exception @pytest.mark.no_fail_on_log_exception