Replace asyncio.iscoroutinefunction (#148738)

This commit is contained in:
Marc Mueller 2025-07-14 21:24:32 +02:00 committed by GitHub
parent 9e3a78b7ef
commit 80eb4fb2f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 22 additions and 18 deletions

View File

@ -2,9 +2,9 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Final, overload
import knx_frontend as knx_panel
@ -116,7 +116,7 @@ def provide_knx(
"KNX integration not loaded.",
)
if asyncio.iscoroutinefunction(func):
if inspect.iscoroutinefunction(func):
@wraps(func)
async def with_knx(

View File

@ -384,7 +384,7 @@ def get_hassjob_callable_job_type(target: Callable[..., Any]) -> HassJobType:
while isinstance(check_target, functools.partial):
check_target = check_target.func
if asyncio.iscoroutinefunction(check_target):
if inspect.iscoroutinefunction(check_target):
return HassJobType.Coroutinefunction
if is_callback(check_target):
return HassJobType.Callback

View File

@ -3,12 +3,12 @@
from __future__ import annotations
import abc
import asyncio
from collections import deque
from collections.abc import Callable, Container, Coroutine, Generator, Iterable
from contextlib import contextmanager
from datetime import datetime, time as dt_time, timedelta
import functools as ft
import inspect
import logging
import re
import sys
@ -359,7 +359,7 @@ async def async_from_config(
while isinstance(check_factory, ft.partial):
check_factory = check_factory.func
if asyncio.iscoroutinefunction(check_factory):
if inspect.iscoroutinefunction(check_factory):
return cast(ConditionCheckerType, await factory(hass, config))
return cast(ConditionCheckerType, factory(config))

View File

@ -2,11 +2,11 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from dataclasses import dataclass
import enum
import functools
import inspect
import linecache
import logging
import sys
@ -397,7 +397,7 @@ def _report_usage_no_integration(
def warn_use[_CallableT: Callable](func: _CallableT, what: str) -> _CallableT:
"""Mock a function to warn when it was about to be used."""
if asyncio.iscoroutinefunction(func):
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def report_use(*args: Any, **kwargs: Any) -> None:

View File

@ -2,10 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from http import HTTPStatus
import inspect
import logging
from typing import Any, Final
@ -45,7 +45,7 @@ def request_handler_factory(
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
"""Wrap the handler classes."""
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
is_coroutinefunction = inspect.iscoroutinefunction(handler)
assert is_coroutinefunction or is_callback(handler), (
"Handler should be a coroutine or a callback."
)

View File

@ -7,6 +7,7 @@ from collections.abc import Callable, Coroutine, Iterable
import dataclasses
from enum import Enum
from functools import cache, partial
import inspect
import logging
from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, cast, override
@ -997,7 +998,7 @@ def verify_domain_control(
service_handler: Callable[[ServiceCall], Any],
) -> Callable[[ServiceCall], Any]:
"""Decorate."""
if not asyncio.iscoroutinefunction(service_handler):
if not inspect.iscoroutinefunction(service_handler):
raise HomeAssistantError("Can only decorate async functions.")
async def check_permissions(call: ServiceCall) -> Any:

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
import functools
import inspect
from typing import Any, Literal, assert_type, cast, overload
from homeassistant.core import HomeAssistant
@ -47,7 +48,7 @@ def singleton[_S, _T, _U](
def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]:
"""Wrap a function with caching logic."""
if not asyncio.iscoroutinefunction(func):
if not inspect.iscoroutinefunction(func):
@functools.lru_cache(maxsize=1)
@bind_hass

View File

@ -8,6 +8,7 @@ from collections import defaultdict
from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass, field
import functools
import inspect
import logging
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, cast
@ -407,7 +408,7 @@ def _trigger_action_wrapper(
check_func = check_func.func
wrapper_func: Callable[..., Any] | Callable[..., Coroutine[Any, Any, Any]]
if asyncio.iscoroutinefunction(check_func):
if inspect.iscoroutinefunction(check_func):
async_action = cast(Callable[..., Coroutine[Any, Any, Any]], action)
@functools.wraps(async_action)

View File

@ -2,10 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable, KeysView, Mapping
from datetime import datetime, timedelta
from functools import wraps
import inspect
import random
import re
import string
@ -125,7 +125,7 @@ class Throttle:
def __call__(self, method: Callable) -> Callable:
"""Caller for the throttle."""
# Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method):
if inspect.iscoroutinefunction(method):
async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled."""

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
import inspect
from typing import Any
from unittest.mock import AsyncMock, MagicMock
@ -191,7 +191,7 @@ async def trigger_subscription_callback(
object_id=object_id,
data=data,
)
if asyncio.iscoroutinefunction(cb_func):
if inspect.iscoroutinefunction(cb_func):
await cb_func(event)
else:
cb_func(event)

View File

@ -2,6 +2,7 @@
import asyncio
from functools import partial
import inspect
import logging
import queue
from unittest.mock import patch
@ -102,7 +103,7 @@ def test_catch_log_exception() -> None:
async def async_meth():
pass
assert asyncio.iscoroutinefunction(
assert inspect.iscoroutinefunction(
logging_util.catch_log_exception(partial(async_meth), lambda: None)
)
@ -120,7 +121,7 @@ def test_catch_log_exception() -> None:
wrapped = logging_util.catch_log_exception(partial(sync_meth), lambda: None)
assert not is_callback(wrapped)
assert not asyncio.iscoroutinefunction(wrapped)
assert not inspect.iscoroutinefunction(wrapped)
@pytest.mark.no_fail_on_log_exception