mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Fix memory leak when firing state_changed events (#72571)
This commit is contained in:
parent
465210784f
commit
049c06061c
@ -746,7 +746,7 @@ class LazyState(State):
|
|||||||
def context(self) -> Context: # type: ignore[override]
|
def context(self) -> Context: # type: ignore[override]
|
||||||
"""State context."""
|
"""State context."""
|
||||||
if self._context is None:
|
if self._context is None:
|
||||||
self._context = Context(id=None) # type: ignore[arg-type]
|
self._context = Context(id=None)
|
||||||
return self._context
|
return self._context
|
||||||
|
|
||||||
@context.setter
|
@context.setter
|
||||||
|
@ -37,7 +37,6 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import attr
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
import yarl
|
import yarl
|
||||||
|
|
||||||
@ -716,14 +715,26 @@ class HomeAssistant:
|
|||||||
self._stopped.set()
|
self._stopped.set()
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=False)
|
|
||||||
class Context:
|
class Context:
|
||||||
"""The context that triggered something."""
|
"""The context that triggered something."""
|
||||||
|
|
||||||
user_id: str | None = attr.ib(default=None)
|
__slots__ = ("user_id", "parent_id", "id", "origin_event")
|
||||||
parent_id: str | None = attr.ib(default=None)
|
|
||||||
id: str = attr.ib(factory=ulid_util.ulid)
|
def __init__(
|
||||||
origin_event: Event | None = attr.ib(default=None, eq=False)
|
self,
|
||||||
|
user_id: str | None = None,
|
||||||
|
parent_id: str | None = None,
|
||||||
|
id: str | None = None, # pylint: disable=redefined-builtin
|
||||||
|
) -> None:
|
||||||
|
"""Init the context."""
|
||||||
|
self.id = id or ulid_util.ulid()
|
||||||
|
self.user_id = user_id
|
||||||
|
self.parent_id = parent_id
|
||||||
|
self.origin_event: Event | None = None
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Compare contexts."""
|
||||||
|
return bool(self.__class__ == other.__class__ and self.id == other.id)
|
||||||
|
|
||||||
def as_dict(self) -> dict[str, str | None]:
|
def as_dict(self) -> dict[str, str | None]:
|
||||||
"""Return a dictionary representation of the context."""
|
"""Return a dictionary representation of the context."""
|
||||||
@ -1163,6 +1174,24 @@ class State:
|
|||||||
context,
|
context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def expire(self) -> None:
|
||||||
|
"""Mark the state as old.
|
||||||
|
|
||||||
|
We give up the original reference to the context to ensure
|
||||||
|
the context can be garbage collected by replacing it with
|
||||||
|
a new one with the same id to ensure the old state
|
||||||
|
can still be examined for comparison against the new state.
|
||||||
|
|
||||||
|
Since we are always going to fire a EVENT_STATE_CHANGED event
|
||||||
|
after we remove a state from the state machine we need to make
|
||||||
|
sure we don't end up holding a reference to the original context
|
||||||
|
since it can never be garbage collected as each event would
|
||||||
|
reference the previous one.
|
||||||
|
"""
|
||||||
|
self.context = Context(
|
||||||
|
self.context.user_id, self.context.parent_id, self.context.id
|
||||||
|
)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Return the comparison of the state."""
|
"""Return the comparison of the state."""
|
||||||
return ( # type: ignore[no-any-return]
|
return ( # type: ignore[no-any-return]
|
||||||
@ -1303,6 +1332,7 @@ class StateMachine:
|
|||||||
if old_state is None:
|
if old_state is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
old_state.expire()
|
||||||
self._bus.async_fire(
|
self._bus.async_fire(
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
{"entity_id": entity_id, "old_state": old_state, "new_state": None},
|
{"entity_id": entity_id, "old_state": old_state, "new_state": None},
|
||||||
@ -1396,7 +1426,6 @@ class StateMachine:
|
|||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
context = Context(id=ulid_util.ulid(dt_util.utc_to_timestamp(now)))
|
context = Context(id=ulid_util.ulid(dt_util.utc_to_timestamp(now)))
|
||||||
|
|
||||||
state = State(
|
state = State(
|
||||||
entity_id,
|
entity_id,
|
||||||
new_state,
|
new_state,
|
||||||
@ -1406,6 +1435,8 @@ class StateMachine:
|
|||||||
context,
|
context,
|
||||||
old_state is None,
|
old_state is None,
|
||||||
)
|
)
|
||||||
|
if old_state is not None:
|
||||||
|
old_state.expire()
|
||||||
self._states[entity_id] = state
|
self._states[entity_id] = state
|
||||||
self._bus.async_fire(
|
self._bus.async_fire(
|
||||||
EVENT_STATE_CHANGED,
|
EVENT_STATE_CHANGED,
|
||||||
|
@ -6,9 +6,11 @@ import array
|
|||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import functools
|
import functools
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
from unittest.mock import MagicMock, Mock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -1829,3 +1831,46 @@ async def test_event_context(hass):
|
|||||||
cancel2()
|
cancel2()
|
||||||
|
|
||||||
assert dummy_event2.context.origin_event == dummy_event
|
assert dummy_event2.context.origin_event == dummy_event
|
||||||
|
|
||||||
|
|
||||||
|
def _get_full_name(obj) -> str:
|
||||||
|
"""Get the full name of an object in memory."""
|
||||||
|
objtype = type(obj)
|
||||||
|
name = objtype.__name__
|
||||||
|
if module := getattr(objtype, "__module__", None):
|
||||||
|
return f"{module}.{name}"
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _get_by_type(full_name: str) -> list[Any]:
|
||||||
|
"""Get all objects in memory with a specific type."""
|
||||||
|
return [obj for obj in gc.get_objects() if _get_full_name(obj) == full_name]
|
||||||
|
|
||||||
|
|
||||||
|
# The logger will hold a strong reference to the event for the life of the tests
|
||||||
|
# so we must patch it out
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not os.environ.get("DEBUG_MEMORY"),
|
||||||
|
reason="Takes too long on the CI",
|
||||||
|
)
|
||||||
|
@patch.object(ha._LOGGER, "debug", lambda *args: None)
|
||||||
|
async def test_state_changed_events_to_not_leak_contexts(hass):
|
||||||
|
"""Test state changed events do not leak contexts."""
|
||||||
|
gc.collect()
|
||||||
|
# Other tests can log Contexts which keep them in memory
|
||||||
|
# so we need to look at how many exist at the start
|
||||||
|
init_count = len(_get_by_type("homeassistant.core.Context"))
|
||||||
|
|
||||||
|
assert len(_get_by_type("homeassistant.core.Context")) == init_count
|
||||||
|
for i in range(20):
|
||||||
|
hass.states.async_set("light.switch", str(i))
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
assert len(_get_by_type("homeassistant.core.Context")) == init_count + 2
|
||||||
|
|
||||||
|
hass.states.async_remove("light.switch")
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
assert len(_get_by_type("homeassistant.core.Context")) == init_count
|
||||||
|
Loading…
x
Reference in New Issue
Block a user