Add debug mode to catch unsafe thread operations using core helpers (#115390)

* adjust

* adjust

* fixes

* one more

* test

* debug

* move to config

* cover

* Update homeassistant/core.py

* set debug from RuntimeConfig

* reduce

* fix message

* raise

* Update homeassistant/core.py

* Update homeassistant/core.py

* no flood check for raise

* cover
This commit is contained in:
J. Nick Koston 2024-04-24 03:36:05 +02:00 committed by GitHub
parent 9d54aa205b
commit 53a179088f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 197 additions and 10 deletions

View File

@ -253,6 +253,8 @@ async def async_setup_hass(
runtime_config.log_no_color,
)
if runtime_config.debug:
hass.config.debug = True
hass.config.safe_mode = runtime_config.safe_mode
hass.config.skip_pip = runtime_config.skip_pip
hass.config.skip_pip_packages = runtime_config.skip_pip_packages

View File

@ -39,6 +39,7 @@ from .const import (
CONF_CUSTOMIZE,
CONF_CUSTOMIZE_DOMAIN,
CONF_CUSTOMIZE_GLOB,
CONF_DEBUG,
CONF_ELEVATION,
CONF_EXTERNAL_URL,
CONF_ID,
@ -391,6 +392,7 @@ CORE_CONFIG_SCHEMA = vol.All(
vol.Optional(CONF_CURRENCY): _validate_currency,
vol.Optional(CONF_COUNTRY): cv.country,
vol.Optional(CONF_LANGUAGE): cv.language,
vol.Optional(CONF_DEBUG): cv.boolean,
}
),
_filter_bad_internal_external_urls,
@ -899,6 +901,9 @@ async def async_process_ha_core_config(hass: HomeAssistant, config: dict) -> Non
if key in config:
setattr(hac, attr, config[key])
if config.get(CONF_DEBUG):
hac.debug = True
_raise_issue_if_legacy_templates(hass, config.get(CONF_LEGACY_TEMPLATES))
_raise_issue_if_historic_currency(hass, hass.config.currency)
_raise_issue_if_no_country(hass, hass.config.country)

View File

@ -296,6 +296,7 @@ CONF_WHILE: Final = "while"
CONF_WHITELIST: Final = "whitelist"
CONF_ALLOWLIST_EXTERNAL_DIRS: Final = "allowlist_external_dirs"
LEGACY_CONF_WHITELIST_EXTERNAL_DIRS: Final = "whitelist_external_dirs"
CONF_DEBUG: Final = "debug"
CONF_XY: Final = "xy"
CONF_ZONE: Final = "zone"

View File

@ -429,6 +429,20 @@ class HomeAssistant:
max_workers=1, thread_name_prefix="ImportExecutor"
)
def verify_event_loop_thread(self, what: str) -> None:
"""Report and raise if we are not running in the event loop thread."""
if (
loop_thread_ident := self.loop.__dict__.get("_thread_ident")
) and loop_thread_ident != threading.get_ident():
from .helpers import frame # pylint: disable=import-outside-toplevel
# frame is a circular import, so we import it here
frame.report(
f"calls {what} from a thread",
error_if_core=True,
error_if_integration=True,
)
@property
def _active_tasks(self) -> set[asyncio.Future[Any]]:
"""Return all active tasks.
@ -503,7 +517,6 @@ class HomeAssistant:
This method is a coroutine.
"""
_LOGGER.info("Starting Home Assistant")
setattr(self.loop, "_thread_ident", threading.get_ident())
self.set_state(CoreState.starting)
self.bus.async_fire_internal(EVENT_CORE_CONFIG_UPDATE)
@ -1451,6 +1464,9 @@ class EventBus:
This method must be run in the event loop.
"""
if self._hass.config.debug:
self._hass.verify_event_loop_thread("async_fire")
if len(event_type) > MAX_LENGTH_EVENT_EVENT_TYPE:
raise MaxLengthExceeded(
event_type, "event_type", MAX_LENGTH_EVENT_EVENT_TYPE
@ -2749,6 +2765,7 @@ class Config:
self.elevation: int = 0
"""Elevation (always in meters regardless of the unit system)."""
self.debug: bool = False
self.location_name: str = "Home"
self.time_zone: str = "UTC"
self.units: UnitSystem = METRIC_SYSTEM
@ -2889,6 +2906,7 @@ class Config:
"country": self.country,
"language": self.language,
"safe_mode": self.safe_mode,
"debug": self.debug,
}
def set_time_zone(self, time_zone_str: str) -> None:

View File

@ -199,6 +199,9 @@ def async_dispatcher_send(
This method must be run in the event loop.
"""
if hass.config.debug:
hass.verify_event_loop_thread("async_dispatcher_send")
if (maybe_dispatchers := hass.data.get(DATA_DISPATCHER)) is None:
return
dispatchers: _DispatcherDataType[*_Ts] = maybe_dispatchers

View File

@ -971,6 +971,8 @@ class Entity(
"""Write the state to the state machine."""
if self.hass is None:
raise RuntimeError(f"Attribute hass is None for {self}")
if self.hass.config.debug:
self.hass.verify_event_loop_thread("async_write_ha_state")
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1

View File

@ -136,6 +136,7 @@ def report(
error_if_core: bool = True,
level: int = logging.WARNING,
log_custom_component_only: bool = False,
error_if_integration: bool = False,
) -> None:
"""Report incorrect usage.
@ -153,14 +154,19 @@ def report(
_LOGGER.warning(msg, stack_info=True)
return
if not log_custom_component_only or integration_frame.custom_integration:
_report_integration(what, integration_frame, level)
if (
error_if_integration
or not log_custom_component_only
or integration_frame.custom_integration
):
_report_integration(what, integration_frame, level, error_if_integration)
def _report_integration(
what: str,
integration_frame: IntegrationFrame,
level: int = logging.WARNING,
error: bool = False,
) -> None:
"""Report incorrect usage in an integration.
@ -168,7 +174,7 @@ def _report_integration(
"""
# Keep track of integrations already reported to prevent flooding
key = f"{integration_frame.filename}:{integration_frame.line_number}"
if key in _REPORTED_INTEGRATIONS:
if not error and key in _REPORTED_INTEGRATIONS:
return
_REPORTED_INTEGRATIONS.add(key)
@ -180,11 +186,11 @@ def _report_integration(
integration_domain=integration_frame.integration,
module=integration_frame.module,
)
integration_type = "custom " if integration_frame.custom_integration else ""
_LOGGER.log(
level,
"Detected that %sintegration '%s' %s at %s, line %s: %s, please %s",
"custom " if integration_frame.custom_integration else "",
integration_type,
integration_frame.integration,
what,
integration_frame.relative_filename,
@ -192,6 +198,15 @@ def _report_integration(
integration_frame.line,
report_issue,
)
if not error:
return
raise RuntimeError(
f"Detected that {integration_type}integration "
f"'{integration_frame.integration}' {what} at "
f"{integration_frame.relative_filename}, line "
f"{integration_frame.line_number}: {integration_frame.line}. "
f"Please {report_issue}."
)
def warn_use(func: _CallableT, what: str) -> _CallableT:

View File

@ -695,6 +695,8 @@ class Template:
**kwargs: Any,
) -> RenderInfo:
"""Render the template and collect an entity filter."""
if self.hass and self.hass.config.debug:
self.hass.verify_event_loop_thread("async_render_to_info")
self._renders += 1
assert self.hass and _render_info.get() is None

View File

@ -107,6 +107,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
def new_event_loop(self) -> asyncio.AbstractEventLoop:
"""Get the event loop."""
loop: asyncio.AbstractEventLoop = super().new_event_loop()
setattr(loop, "_thread_ident", threading.get_ident())
loop.set_exception_handler(_async_loop_exception_handler)
if self.debug:
loop.set_debug(True)

View File

@ -52,8 +52,7 @@ def run_callback_threadsafe(
Return a concurrent.futures.Future to access the result.
"""
ident = loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
if (ident := loop.__dict__.get("_thread_ident")) and ident == threading.get_ident():
raise RuntimeError("Cannot be called from within the event loop")
future: concurrent.futures.Future[_T] = concurrent.futures.Future()

View File

@ -239,3 +239,24 @@ async def test_dispatcher_add_dispatcher(hass: HomeAssistant) -> None:
async_dispatcher_send(hass, "test", 5)
assert calls == [3, 4, 4, 5, 5]
async def test_thread_safety_checks(hass: HomeAssistant) -> None:
"""Test dispatcher thread safety checks."""
hass.config.debug = True
calls = []
@callback
def _dispatcher(data):
calls.append(data)
async_dispatcher_connect(hass, "test", _dispatcher)
with pytest.raises(
RuntimeError,
match="Detected code that calls async_dispatcher_send from a thread.",
):
await hass.async_add_executor_job(async_dispatcher_send, hass, "test", 3)
async_dispatcher_send(hass, "test", 4)
assert calls == [4]

View File

@ -2594,3 +2594,24 @@ async def test_get_hassjob_type(hass: HomeAssistant) -> None:
assert ent_1.get_hassjob_type("update") is HassJobType.Executor
assert ent_1.get_hassjob_type("async_update") is HassJobType.Coroutinefunction
assert ent_1.get_hassjob_type("update_callback") is HassJobType.Callback
async def test_async_write_ha_state_thread_safety(hass: HomeAssistant) -> None:
"""Test async_write_ha_state thread safety."""
hass.config.debug = True
ent = entity.Entity()
ent.entity_id = "test.any"
ent.hass = hass
ent.async_write_ha_state()
assert hass.states.get(ent.entity_id)
ent2 = entity.Entity()
ent2.entity_id = "test.any2"
ent2.hass = hass
with pytest.raises(
RuntimeError,
match="Detected code that calls async_write_ha_state from a thread.",
):
await hass.async_add_executor_job(ent2.async_write_ha_state)
assert not hass.states.get(ent2.entity_id)

View File

@ -205,3 +205,45 @@ async def test_report_missing_integration_frame(
frame.report(what, error_if_core=False, log_custom_component_only=True)
assert caplog.text == ""
@pytest.mark.parametrize("run_count", [1, 2])
# Run this twice to make sure the flood check does not
# kick in when error_if_integration=True
async def test_report_error_if_integration(
caplog: pytest.LogCaptureFixture, run_count: int
) -> None:
"""Test RuntimeError is raised if error_if_integration is set."""
frames = extract_stack_to_frame(
[
Mock(
filename="/home/paulus/homeassistant/core.py",
lineno="23",
line="do_something()",
),
Mock(
filename="/home/paulus/homeassistant/components/hue/light.py",
lineno="23",
line="self.light.is_on",
),
Mock(
filename="/home/paulus/aiohue/lights.py",
lineno="2",
line="something()",
),
]
)
with (
patch(
"homeassistant.helpers.frame.get_current_frame",
return_value=frames,
),
pytest.raises(
RuntimeError,
match=(
"Detected that integration 'hue' did a bad"
" thing at homeassistant/components/hue/light.py"
),
),
):
frame.report("did a bad thing", error_if_integration=True)

View File

@ -5757,3 +5757,20 @@ async def test_label_areas(
info = render_to_info(hass, f"{{{{ '{label.name}' | label_areas }}}}")
assert_result_info(info, [master_bedroom.id])
assert info.rate_limit is None
async def test_template_thread_safety_checks(hass: HomeAssistant) -> None:
"""Test template thread safety checks."""
hass.states.async_set("sensor.test", "23")
template_str = "{{ states('sensor.test') }}"
template_obj = template.Template(template_str, None)
template_obj.hass = hass
hass.config.debug = True
with pytest.raises(
RuntimeError,
match="Detected code that calls async_render_to_info from a thread.",
):
await hass.async_add_executor_job(template_obj.async_render_to_info)
assert template_obj.async_render_to_info().result() == 23

View File

@ -13,7 +13,7 @@ import pytest
from homeassistant import bootstrap, loader, runner
import homeassistant.config as config_util
from homeassistant.config_entries import HANDLERS, ConfigEntry
from homeassistant.const import SIGNAL_BOOTSTRAP_INTEGRATIONS
from homeassistant.const import CONF_DEBUG, SIGNAL_BOOTSTRAP_INTEGRATIONS
from homeassistant.core import CoreState, HomeAssistant, async_get_hass, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_connect
@ -112,6 +112,16 @@ async def test_empty_setup(hass: HomeAssistant) -> None:
assert domain in hass.config.components, domain
@pytest.mark.parametrize("load_registries", [False])
async def test_config_does_not_turn_off_debug(hass: HomeAssistant) -> None:
"""Test that config does not turn off debug if its turned on by runtime config."""
# Mock that its turned on from RuntimeConfig
hass.config.debug = True
await bootstrap.async_from_config_dict({CONF_DEBUG: False}, hass)
assert hass.config.debug is True
@pytest.mark.parametrize("load_registries", [False])
async def test_preload_translations(hass: HomeAssistant) -> None:
"""Test translations are preloaded for all frontend deps and base platforms."""
@ -599,6 +609,7 @@ async def test_setup_hass(
log_no_color=log_no_color,
skip_pip=True,
recovery_mode=False,
debug=True,
),
)
@ -619,6 +630,9 @@ async def test_setup_hass(
assert len(mock_ensure_config_exists.mock_calls) == 1
assert len(mock_process_ha_config_upgrade.mock_calls) == 1
# debug in RuntimeConfig should set it it in hass.config
assert hass.config.debug is True
assert hass == async_get_hass()

View File

@ -857,6 +857,7 @@ async def test_loading_configuration(hass: HomeAssistant) -> None:
"internal_url": "http://example.local",
"media_dirs": {"mymedia": "/usr"},
"legacy_templates": True,
"debug": True,
"currency": "EUR",
"country": "SE",
"language": "sv",
@ -877,6 +878,7 @@ async def test_loading_configuration(hass: HomeAssistant) -> None:
assert hass.config.media_dirs == {"mymedia": "/usr"}
assert hass.config.config_source is ConfigSource.YAML
assert hass.config.legacy_templates is True
assert hass.config.debug is True
assert hass.config.currency == "EUR"
assert hass.config.country == "SE"
assert hass.config.language == "sv"

View File

@ -1990,6 +1990,7 @@ async def test_config_as_dict() -> None:
"country": None,
"language": "en",
"safe_mode": False,
"debug": False,
}
assert expected == config.as_dict()
@ -3439,3 +3440,22 @@ async def test_top_level_components(hass: HomeAssistant) -> None:
hass.config.components.remove("homeassistant.scene")
with pytest.raises(NotImplementedError):
hass.config.components.discard("homeassistant")
async def test_debug_mode_defaults_to_off(hass: HomeAssistant) -> None:
"""Test debug mode defaults to off."""
assert not hass.config.debug
async def test_async_fire_thread_safety(hass: HomeAssistant) -> None:
"""Test async_fire thread safety."""
hass.config.debug = True
events = async_capture_events(hass, "test_event")
hass.bus.async_fire("test_event")
with pytest.raises(
RuntimeError, match="Detected code that calls async_fire from a thread."
):
await hass.async_add_executor_job(hass.bus.async_fire, "test_event")
assert len(events) == 1

View File

@ -76,7 +76,8 @@ async def test_run_callback_threadsafe(hass: HomeAssistant) -> None:
nonlocal it_ran
it_ran = True
assert hasync.run_callback_threadsafe(hass.loop, callback)
with patch.dict(hass.loop.__dict__, {"_thread_ident": -1}):
assert hasync.run_callback_threadsafe(hass.loop, callback)
assert it_ran is False
# Verify that async_block_till_done will flush
@ -95,6 +96,7 @@ async def test_callback_is_always_scheduled(hass: HomeAssistant) -> None:
hasync.shutdown_run_callback_threadsafe(hass.loop)
with (
patch.dict(hass.loop.__dict__, {"_thread_ident": -1}),
patch.object(hass.loop, "call_soon_threadsafe") as mock_call_soon_threadsafe,
pytest.raises(RuntimeError),
):