diff --git a/homeassistant/core.py b/homeassistant/core.py index 17b8b5f2e85..cbfc8097c7f 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -6,7 +6,16 @@ of entities and react to changes. from __future__ import annotations import asyncio -from collections.abc import Callable, Collection, Coroutine, Iterable, Mapping +from collections import UserDict, defaultdict +from collections.abc import ( + Callable, + Collection, + Coroutine, + Iterable, + KeysView, + Mapping, + ValuesView, +) import concurrent.futures from contextlib import suppress import datetime @@ -1413,15 +1422,59 @@ class State: ) +class States(UserDict[str, State]): + """Container for states, maps entity_id -> State. + + Maintains an additional index: + - domain -> dict[str, State] + """ + + def __init__(self) -> None: + """Initialize the container.""" + super().__init__() + self._domain_index: defaultdict[str, dict[str, State]] = defaultdict(dict) + + def values(self) -> ValuesView[State]: + """Return the underlying values to avoid __iter__ overhead.""" + return self.data.values() + + def __setitem__(self, key: str, entry: State) -> None: + """Add an item.""" + self.data[key] = entry + self._domain_index[entry.domain][entry.entity_id] = entry + + def __delitem__(self, key: str) -> None: + """Remove an item.""" + entry = self[key] + del self._domain_index[entry.domain][entry.entity_id] + super().__delitem__(key) + + def domain_entity_ids(self, key: str) -> KeysView[str] | tuple[()]: + """Get all entity_ids for a domain.""" + # Avoid polluting _domain_index with non-existing domains + if key not in self._domain_index: + return () + return self._domain_index[key].keys() + + def domain_states(self, key: str) -> ValuesView[State] | tuple[()]: + """Get all states for a domain.""" + # Avoid polluting _domain_index with non-existing domains + if key not in self._domain_index: + return () + return self._domain_index[key].values() + + class StateMachine: """Helper class that tracks the state of different entities.""" - __slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop") + __slots__ = ("_states", "_states_data", "_reservations", "_bus", "_loop") def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None: """Initialize state machine.""" - self._states: dict[str, State] = {} - self._domain_index: dict[str, dict[str, State]] = {} + self._states = States() + # _states_data is used to access the States backing dict directly to speed + # up read operations + self._states_data = self._states.data self._reservations: set[str] = set() self._bus = bus self._loop = loop @@ -1442,16 +1495,15 @@ class StateMachine: This method must be run in the event loop. """ if domain_filter is None: - return list(self._states) + return list(self._states_data) if isinstance(domain_filter, str): - return list(self._domain_index.get(domain_filter.lower(), ())) + return list(self._states.domain_entity_ids(domain_filter.lower())) - states: list[str] = [] + entity_ids: list[str] = [] for domain in domain_filter: - if domain_index := self._domain_index.get(domain): - states.extend(domain_index) - return states + entity_ids.extend(self._states.domain_entity_ids(domain)) + return entity_ids @callback def async_entity_ids_count( @@ -1462,12 +1514,14 @@ class StateMachine: This method must be run in the event loop. """ if domain_filter is None: - return len(self._states) + return len(self._states_data) if isinstance(domain_filter, str): - return len(self._domain_index.get(domain_filter.lower(), ())) + return len(self._states.domain_entity_ids(domain_filter.lower())) - return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter) + return sum( + len(self._states.domain_entity_ids(domain)) for domain in domain_filter + ) def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]: """Create a list of all states.""" @@ -1484,15 +1538,14 @@ class StateMachine: This method must be run in the event loop. """ if domain_filter is None: - return list(self._states.values()) + return list(self._states_data.values()) if isinstance(domain_filter, str): - return list(self._domain_index.get(domain_filter.lower(), {}).values()) + return list(self._states.domain_states(domain_filter.lower())) states: list[State] = [] for domain in domain_filter: - if domain_index := self._domain_index.get(domain): - states.extend(domain_index.values()) + states.extend(self._states.domain_states(domain)) return states def get(self, entity_id: str) -> State | None: @@ -1500,7 +1553,7 @@ class StateMachine: Async friendly. """ - return self._states.get(entity_id.lower()) + return self._states_data.get(entity_id.lower()) def is_state(self, entity_id: str, state: str) -> bool: """Test if entity exists and is in specified state. @@ -1534,7 +1587,6 @@ class StateMachine: if old_state is None: return False - self._domain_index[old_state.domain].pop(entity_id) old_state.expire() self._bus.async_fire( EVENT_STATE_CHANGED, @@ -1579,7 +1631,7 @@ class StateMachine: entity_id are added. """ entity_id = entity_id.lower() - if entity_id in self._states or entity_id in self._reservations: + if entity_id in self._states_data or entity_id in self._reservations: raise HomeAssistantError( "async_reserve must not be called once the state is in the state" " machine." @@ -1591,7 +1643,9 @@ class StateMachine: def async_available(self, entity_id: str) -> bool: """Check to see if an entity_id is available to be used.""" entity_id = entity_id.lower() - return entity_id not in self._states and entity_id not in self._reservations + return ( + entity_id not in self._states_data and entity_id not in self._reservations + ) @callback def async_set( @@ -1614,7 +1668,7 @@ class StateMachine: entity_id = entity_id.lower() new_state = str(new_state) attributes = attributes or {} - if (old_state := self._states.get(entity_id)) is None: + if (old_state := self._states_data.get(entity_id)) is None: same_state = False same_attr = False last_changed = None @@ -1656,10 +1710,6 @@ class StateMachine: if old_state is not None: old_state.expire() self._states[entity_id] = state - if not (domain_index := self._domain_index.get(state.domain)): - domain_index = {} - self._domain_index[state.domain] = domain_index - domain_index[entity_id] = state self._bus.async_fire( EVENT_STATE_CHANGED, {"entity_id": entity_id, "old_state": old_state, "new_state": state}, diff --git a/tests/test_core.py b/tests/test_core.py index 5dcbb81db68..c5ce9eb0881 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -15,6 +15,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from pytest_unordered import unordered import voluptuous as vol from homeassistant.const import ( @@ -1031,17 +1032,18 @@ async def test_statemachine_is_state(hass: HomeAssistant) -> None: async def test_statemachine_entity_ids(hass: HomeAssistant) -> None: - """Test get_entity_ids method.""" + """Test async_entity_ids method.""" + assert hass.states.async_entity_ids() == [] + assert hass.states.async_entity_ids("light") == [] + assert hass.states.async_entity_ids(("light", "switch", "other")) == [] + hass.states.async_set("light.bowl", "on", {}) hass.states.async_set("SWITCH.AC", "off", {}) - ent_ids = hass.states.async_entity_ids() - assert len(ent_ids) == 2 - assert "light.bowl" in ent_ids - assert "switch.ac" in ent_ids - - ent_ids = hass.states.async_entity_ids("light") - assert len(ent_ids) == 1 - assert "light.bowl" in ent_ids + assert hass.states.async_entity_ids() == unordered(["light.bowl", "switch.ac"]) + assert hass.states.async_entity_ids("light") == ["light.bowl"] + assert hass.states.async_entity_ids(("light", "switch", "other")) == unordered( + ["light.bowl", "switch.ac"] + ) states = sorted(state.entity_id for state in hass.states.async_all()) assert states == ["light.bowl", "switch.ac"] @@ -1902,6 +1904,9 @@ async def test_chained_logging_misses_log_timeout( async def test_async_all(hass: HomeAssistant) -> None: """Test async_all.""" + assert hass.states.async_all() == [] + assert hass.states.async_all("light") == [] + assert hass.states.async_all(["light", "switch"]) == [] hass.states.async_set("switch.link", "on") hass.states.async_set("light.bowl", "on") @@ -1926,6 +1931,10 @@ async def test_async_all(hass: HomeAssistant) -> None: async def test_async_entity_ids_count(hass: HomeAssistant) -> None: """Test async_entity_ids_count.""" + assert hass.states.async_entity_ids_count() == 0 + assert hass.states.async_entity_ids_count("light") == 0 + assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 0 + hass.states.async_set("switch.link", "on") hass.states.async_set("light.bowl", "on") hass.states.async_set("light.frog", "on")