mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 20:57:21 +00:00
Add debug mode to catch unsafe thread operations using core helpers (#115390)
* adjust * adjust * fixes * one more * test * debug * move to config * cover * Update homeassistant/core.py * set debug from RuntimeConfig * reduce * fix message * raise * Update homeassistant/core.py * Update homeassistant/core.py * no flood check for raise * cover
This commit is contained in:
parent
9d54aa205b
commit
53a179088f
@ -253,6 +253,8 @@ async def async_setup_hass(
|
||||
runtime_config.log_no_color,
|
||||
)
|
||||
|
||||
if runtime_config.debug:
|
||||
hass.config.debug = True
|
||||
hass.config.safe_mode = runtime_config.safe_mode
|
||||
hass.config.skip_pip = runtime_config.skip_pip
|
||||
hass.config.skip_pip_packages = runtime_config.skip_pip_packages
|
||||
|
@ -39,6 +39,7 @@ from .const import (
|
||||
CONF_CUSTOMIZE,
|
||||
CONF_CUSTOMIZE_DOMAIN,
|
||||
CONF_CUSTOMIZE_GLOB,
|
||||
CONF_DEBUG,
|
||||
CONF_ELEVATION,
|
||||
CONF_EXTERNAL_URL,
|
||||
CONF_ID,
|
||||
@ -391,6 +392,7 @@ CORE_CONFIG_SCHEMA = vol.All(
|
||||
vol.Optional(CONF_CURRENCY): _validate_currency,
|
||||
vol.Optional(CONF_COUNTRY): cv.country,
|
||||
vol.Optional(CONF_LANGUAGE): cv.language,
|
||||
vol.Optional(CONF_DEBUG): cv.boolean,
|
||||
}
|
||||
),
|
||||
_filter_bad_internal_external_urls,
|
||||
@ -899,6 +901,9 @@ async def async_process_ha_core_config(hass: HomeAssistant, config: dict) -> Non
|
||||
if key in config:
|
||||
setattr(hac, attr, config[key])
|
||||
|
||||
if config.get(CONF_DEBUG):
|
||||
hac.debug = True
|
||||
|
||||
_raise_issue_if_legacy_templates(hass, config.get(CONF_LEGACY_TEMPLATES))
|
||||
_raise_issue_if_historic_currency(hass, hass.config.currency)
|
||||
_raise_issue_if_no_country(hass, hass.config.country)
|
||||
|
@ -296,6 +296,7 @@ CONF_WHILE: Final = "while"
|
||||
CONF_WHITELIST: Final = "whitelist"
|
||||
CONF_ALLOWLIST_EXTERNAL_DIRS: Final = "allowlist_external_dirs"
|
||||
LEGACY_CONF_WHITELIST_EXTERNAL_DIRS: Final = "whitelist_external_dirs"
|
||||
CONF_DEBUG: Final = "debug"
|
||||
CONF_XY: Final = "xy"
|
||||
CONF_ZONE: Final = "zone"
|
||||
|
||||
|
@ -429,6 +429,20 @@ class HomeAssistant:
|
||||
max_workers=1, thread_name_prefix="ImportExecutor"
|
||||
)
|
||||
|
||||
def verify_event_loop_thread(self, what: str) -> None:
|
||||
"""Report and raise if we are not running in the event loop thread."""
|
||||
if (
|
||||
loop_thread_ident := self.loop.__dict__.get("_thread_ident")
|
||||
) and loop_thread_ident != threading.get_ident():
|
||||
from .helpers import frame # pylint: disable=import-outside-toplevel
|
||||
|
||||
# frame is a circular import, so we import it here
|
||||
frame.report(
|
||||
f"calls {what} from a thread",
|
||||
error_if_core=True,
|
||||
error_if_integration=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def _active_tasks(self) -> set[asyncio.Future[Any]]:
|
||||
"""Return all active tasks.
|
||||
@ -503,7 +517,6 @@ class HomeAssistant:
|
||||
This method is a coroutine.
|
||||
"""
|
||||
_LOGGER.info("Starting Home Assistant")
|
||||
setattr(self.loop, "_thread_ident", threading.get_ident())
|
||||
|
||||
self.set_state(CoreState.starting)
|
||||
self.bus.async_fire_internal(EVENT_CORE_CONFIG_UPDATE)
|
||||
@ -1451,6 +1464,9 @@ class EventBus:
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if self._hass.config.debug:
|
||||
self._hass.verify_event_loop_thread("async_fire")
|
||||
|
||||
if len(event_type) > MAX_LENGTH_EVENT_EVENT_TYPE:
|
||||
raise MaxLengthExceeded(
|
||||
event_type, "event_type", MAX_LENGTH_EVENT_EVENT_TYPE
|
||||
@ -2749,6 +2765,7 @@ class Config:
|
||||
self.elevation: int = 0
|
||||
"""Elevation (always in meters regardless of the unit system)."""
|
||||
|
||||
self.debug: bool = False
|
||||
self.location_name: str = "Home"
|
||||
self.time_zone: str = "UTC"
|
||||
self.units: UnitSystem = METRIC_SYSTEM
|
||||
@ -2889,6 +2906,7 @@ class Config:
|
||||
"country": self.country,
|
||||
"language": self.language,
|
||||
"safe_mode": self.safe_mode,
|
||||
"debug": self.debug,
|
||||
}
|
||||
|
||||
def set_time_zone(self, time_zone_str: str) -> None:
|
||||
|
@ -199,6 +199,9 @@ def async_dispatcher_send(
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
if hass.config.debug:
|
||||
hass.verify_event_loop_thread("async_dispatcher_send")
|
||||
|
||||
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
|
||||
return
|
||||
dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers
|
||||
|
@ -971,6 +971,8 @@ class Entity(
|
||||
"""Write the state to the state machine."""
|
||||
if self.hass is None:
|
||||
raise RuntimeError(f"Attribute hass is None for {self}")
|
||||
if self.hass.config.debug:
|
||||
self.hass.verify_event_loop_thread("async_write_ha_state")
|
||||
|
||||
# The check for self.platform guards against integrations not using an
|
||||
# EntityComponent and can be removed in HA Core 2024.1
|
||||
|
@ -136,6 +136,7 @@ def report(
|
||||
error_if_core: bool = True,
|
||||
level: int = logging.WARNING,
|
||||
log_custom_component_only: bool = False,
|
||||
error_if_integration: bool = False,
|
||||
) -> None:
|
||||
"""Report incorrect usage.
|
||||
|
||||
@ -153,14 +154,19 @@ def report(
|
||||
_LOGGER.warning(msg, stack_info=True)
|
||||
return
|
||||
|
||||
if not log_custom_component_only or integration_frame.custom_integration:
|
||||
_report_integration(what, integration_frame, level)
|
||||
if (
|
||||
error_if_integration
|
||||
or not log_custom_component_only
|
||||
or integration_frame.custom_integration
|
||||
):
|
||||
_report_integration(what, integration_frame, level, error_if_integration)
|
||||
|
||||
|
||||
def _report_integration(
|
||||
what: str,
|
||||
integration_frame: IntegrationFrame,
|
||||
level: int = logging.WARNING,
|
||||
error: bool = False,
|
||||
) -> None:
|
||||
"""Report incorrect usage in an integration.
|
||||
|
||||
@ -168,7 +174,7 @@ def _report_integration(
|
||||
"""
|
||||
# Keep track of integrations already reported to prevent flooding
|
||||
key = f"{integration_frame.filename}:{integration_frame.line_number}"
|
||||
if key in _REPORTED_INTEGRATIONS:
|
||||
if not error and key in _REPORTED_INTEGRATIONS:
|
||||
return
|
||||
_REPORTED_INTEGRATIONS.add(key)
|
||||
|
||||
@ -180,11 +186,11 @@ def _report_integration(
|
||||
integration_domain=integration_frame.integration,
|
||||
module=integration_frame.module,
|
||||
)
|
||||
|
||||
integration_type = "custom " if integration_frame.custom_integration else ""
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Detected that %sintegration '%s' %s at %s, line %s: %s, please %s",
|
||||
"custom " if integration_frame.custom_integration else "",
|
||||
integration_type,
|
||||
integration_frame.integration,
|
||||
what,
|
||||
integration_frame.relative_filename,
|
||||
@ -192,6 +198,15 @@ def _report_integration(
|
||||
integration_frame.line,
|
||||
report_issue,
|
||||
)
|
||||
if not error:
|
||||
return
|
||||
raise RuntimeError(
|
||||
f"Detected that {integration_type}integration "
|
||||
f"'{integration_frame.integration}' {what} at "
|
||||
f"{integration_frame.relative_filename}, line "
|
||||
f"{integration_frame.line_number}: {integration_frame.line}. "
|
||||
f"Please {report_issue}."
|
||||
)
|
||||
|
||||
|
||||
def warn_use(func: _CallableT, what: str) -> _CallableT:
|
||||
|
@ -695,6 +695,8 @@ class Template:
|
||||
**kwargs: Any,
|
||||
) -> RenderInfo:
|
||||
"""Render the template and collect an entity filter."""
|
||||
if self.hass and self.hass.config.debug:
|
||||
self.hass.verify_event_loop_thread("async_render_to_info")
|
||||
self._renders += 1
|
||||
assert self.hass and _render_info.get() is None
|
||||
|
||||
|
@ -107,6 +107,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
|
||||
def new_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
"""Get the event loop."""
|
||||
loop: asyncio.AbstractEventLoop = super().new_event_loop()
|
||||
setattr(loop, "_thread_ident", threading.get_ident())
|
||||
loop.set_exception_handler(_async_loop_exception_handler)
|
||||
if self.debug:
|
||||
loop.set_debug(True)
|
||||
|
@ -52,8 +52,7 @@ def run_callback_threadsafe(
|
||||
|
||||
Return a concurrent.futures.Future to access the result.
|
||||
"""
|
||||
ident = loop.__dict__.get("_thread_ident")
|
||||
if ident is not None and ident == threading.get_ident():
|
||||
if (ident := loop.__dict__.get("_thread_ident")) and ident == threading.get_ident():
|
||||
raise RuntimeError("Cannot be called from within the event loop")
|
||||
|
||||
future: concurrent.futures.Future[_T] = concurrent.futures.Future()
|
||||
|
@ -239,3 +239,24 @@ async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None:
|
||||
async_dispatcher_send(hass, "test", 5)
|
||||
|
||||
assert calls == [3, 4, 4, 5, 5]
|
||||
|
||||
|
||||
async def test_thread_safety_checks(hass: HomeAssistant) -> None:
|
||||
"""Test dispatcher thread safety checks."""
|
||||
hass.config.debug = True
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def _dispatcher(data):
|
||||
calls.append(data)
|
||||
|
||||
async_dispatcher_connect(hass, "test", _dispatcher)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Detected code that calls async_dispatcher_send from a thread.",
|
||||
):
|
||||
await hass.async_add_executor_job(async_dispatcher_send, hass, "test", 3)
|
||||
|
||||
async_dispatcher_send(hass, "test", 4)
|
||||
assert calls == [4]
|
||||
|
@ -2594,3 +2594,24 @@ async def test_get_hassjob_type(hass: HomeAssistant) -> None:
|
||||
assert ent_1.get_hassjob_type("update") is HassJobType.Executor
|
||||
assert ent_1.get_hassjob_type("async_update") is HassJobType.Coroutinefunction
|
||||
assert ent_1.get_hassjob_type("update_callback") is HassJobType.Callback
|
||||
|
||||
|
||||
async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None:
|
||||
"""Test async_write_ha_state thread safety."""
|
||||
hass.config.debug = True
|
||||
|
||||
ent = entity.Entity()
|
||||
ent.entity_id = "test.any"
|
||||
ent.hass = hass
|
||||
ent.async_write_ha_state()
|
||||
assert hass.states.get(ent.entity_id)
|
||||
|
||||
ent2 = entity.Entity()
|
||||
ent2.entity_id = "test.any2"
|
||||
ent2.hass = hass
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Detected code that calls async_write_ha_state from a thread.",
|
||||
):
|
||||
await hass.async_add_executor_job(ent2.async_write_ha_state)
|
||||
assert not hass.states.get(ent2.entity_id)
|
||||
|
@ -205,3 +205,45 @@ async def test_report_missing_integration_frame(
|
||||
|
||||
frame.report(what, error_if_core=False, log_custom_component_only=True)
|
||||
assert caplog.text == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("run_count", [1, 2])
|
||||
# Run this twice to make sure the flood check does not
|
||||
# kick in when error_if_integration=True
|
||||
async def test_report_error_if_integration(
|
||||
caplog: pytest.LogCaptureFixture, run_count: int
|
||||
) -> None:
|
||||
"""Test RuntimeError is raised if error_if_integration is set."""
|
||||
frames = extract_stack_to_frame(
|
||||
[
|
||||
Mock(
|
||||
filename="/home/paulus/homeassistant/core.py",
|
||||
lineno="23",
|
||||
line="do_something()",
|
||||
),
|
||||
Mock(
|
||||
filename="/home/paulus/homeassistant/components/hue/light.py",
|
||||
lineno="23",
|
||||
line="self.light.is_on",
|
||||
),
|
||||
Mock(
|
||||
filename="/home/paulus/aiohue/lights.py",
|
||||
lineno="2",
|
||||
line="something()",
|
||||
),
|
||||
]
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.helpers.frame.get_current_frame",
|
||||
return_value=frames,
|
||||
),
|
||||
pytest.raises(
|
||||
RuntimeError,
|
||||
match=(
|
||||
"Detected that integration 'hue' did a bad"
|
||||
" thing at homeassistant/components/hue/light.py"
|
||||
),
|
||||
),
|
||||
):
|
||||
frame.report("did a bad thing", error_if_integration=True)
|
||||
|
@ -5757,3 +5757,20 @@ async def test_label_areas(
|
||||
info = render_to_info(hass, f"{{{{ '{label.name}' | label_areas }}}}")
|
||||
assert_result_info(info, [master_bedroom.id])
|
||||
assert info.rate_limit is None
|
||||
|
||||
|
||||
async def test_template_thread_safety_checks(hass: HomeAssistant) -> None:
|
||||
"""Test template thread safety checks."""
|
||||
hass.states.async_set("sensor.test", "23")
|
||||
template_str = "{{ states('sensor.test') }}"
|
||||
template_obj = template.Template(template_str, None)
|
||||
template_obj.hass = hass
|
||||
hass.config.debug = True
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Detected code that calls async_render_to_info from a thread.",
|
||||
):
|
||||
await hass.async_add_executor_job(template_obj.async_render_to_info)
|
||||
|
||||
assert template_obj.async_render_to_info().result() == 23
|
||||
|
@ -13,7 +13,7 @@ import pytest
|
||||
from homeassistant import bootstrap, loader, runner
|
||||
import homeassistant.config as config_util
|
||||
from homeassistant.config_entries import HANDLERS, ConfigEntry
|
||||
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATIONS
|
||||
from homeassistant.const import CONF_DEBUG, SIGNAL_BOOTSTRAP_INTEGRATIONS
|
||||
from homeassistant.core import CoreState, HomeAssistant, async_get_hass, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
@ -112,6 +112,16 @@ async def test_empty_setup(hass: HomeAssistant) -> None:
|
||||
assert domain in hass.config.components, domain
|
||||
|
||||
|
||||
@pytest.mark.parametrize("load_registries", [False])
|
||||
async def test_config_does_not_turn_off_debug(hass: HomeAssistant) -> None:
|
||||
"""Test that config does not turn off debug if its turned on by runtime config."""
|
||||
# Mock that its turned on from RuntimeConfig
|
||||
hass.config.debug = True
|
||||
|
||||
await bootstrap.async_from_config_dict({CONF_DEBUG: False}, hass)
|
||||
assert hass.config.debug is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("load_registries", [False])
|
||||
async def test_preload_translations(hass: HomeAssistant) -> None:
|
||||
"""Test translations are preloaded for all frontend deps and base platforms."""
|
||||
@ -599,6 +609,7 @@ async def test_setup_hass(
|
||||
log_no_color=log_no_color,
|
||||
skip_pip=True,
|
||||
recovery_mode=False,
|
||||
debug=True,
|
||||
),
|
||||
)
|
||||
|
||||
@ -619,6 +630,9 @@ async def test_setup_hass(
|
||||
assert len(mock_ensure_config_exists.mock_calls) == 1
|
||||
assert len(mock_process_ha_config_upgrade.mock_calls) == 1
|
||||
|
||||
# debug in RuntimeConfig should set it it in hass.config
|
||||
assert hass.config.debug is True
|
||||
|
||||
assert hass == async_get_hass()
|
||||
|
||||
|
||||
|
@ -857,6 +857,7 @@ async def test_loading_configuration(hass: HomeAssistant) -> None:
|
||||
"internal_url": "http://example.local",
|
||||
"media_dirs": {"mymedia": "/usr"},
|
||||
"legacy_templates": True,
|
||||
"debug": True,
|
||||
"currency": "EUR",
|
||||
"country": "SE",
|
||||
"language": "sv",
|
||||
@ -877,6 +878,7 @@ async def test_loading_configuration(hass: HomeAssistant) -> None:
|
||||
assert hass.config.media_dirs == {"mymedia": "/usr"}
|
||||
assert hass.config.config_source is ConfigSource.YAML
|
||||
assert hass.config.legacy_templates is True
|
||||
assert hass.config.debug is True
|
||||
assert hass.config.currency == "EUR"
|
||||
assert hass.config.country == "SE"
|
||||
assert hass.config.language == "sv"
|
||||
|
@ -1990,6 +1990,7 @@ async def test_config_as_dict() -> None:
|
||||
"country": None,
|
||||
"language": "en",
|
||||
"safe_mode": False,
|
||||
"debug": False,
|
||||
}
|
||||
|
||||
assert expected == config.as_dict()
|
||||
@ -3439,3 +3440,22 @@ async def test_top_level_components(hass: HomeAssistant) -> None:
|
||||
hass.config.components.remove("homeassistant.scene")
|
||||
with pytest.raises(NotImplementedError):
|
||||
hass.config.components.discard("homeassistant")
|
||||
|
||||
|
||||
async def test_debug_mode_defaults_to_off(hass: HomeAssistant) -> None:
|
||||
"""Test debug mode defaults to off."""
|
||||
assert not hass.config.debug
|
||||
|
||||
|
||||
async def test_async_fire_thread_safety(hass: HomeAssistant) -> None:
|
||||
"""Test async_fire thread safety."""
|
||||
hass.config.debug = True
|
||||
|
||||
events = async_capture_events(hass, "test_event")
|
||||
hass.bus.async_fire("test_event")
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Detected code that calls async_fire from a thread."
|
||||
):
|
||||
await hass.async_add_executor_job(hass.bus.async_fire, "test_event")
|
||||
|
||||
assert len(events) == 1
|
||||
|
@ -76,7 +76,8 @@ async def test_run_callback_threadsafe(hass: HomeAssistant) -> None:
|
||||
nonlocal it_ran
|
||||
it_ran = True
|
||||
|
||||
assert hasync.run_callback_threadsafe(hass.loop, callback)
|
||||
with patch.dict(hass.loop.__dict__, {"_thread_ident": -1}):
|
||||
assert hasync.run_callback_threadsafe(hass.loop, callback)
|
||||
assert it_ran is False
|
||||
|
||||
# Verify that async_block_till_done will flush
|
||||
@ -95,6 +96,7 @@ async def test_callback_is_always_scheduled(hass: HomeAssistant) -> None:
|
||||
hasync.shutdown_run_callback_threadsafe(hass.loop)
|
||||
|
||||
with (
|
||||
patch.dict(hass.loop.__dict__, {"_thread_ident": -1}),
|
||||
patch.object(hass.loop, "call_soon_threadsafe") as mock_call_soon_threadsafe,
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
|
Loading…
x
Reference in New Issue
Block a user