diff --git a/homeassistant/core.py b/homeassistant/core.py index 89269ae9158..bd596780759 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -1261,7 +1261,7 @@ class State: "State max length is 255 characters." ) - self.entity_id = entity_id.lower() + self.entity_id = entity_id self.state = state self.attributes = ReadOnlyDict(attributes or {}) self.last_updated = last_updated or dt_util.utcnow() @@ -1412,11 +1412,12 @@ class State: class StateMachine: """Helper class that tracks the state of different entities.""" - __slots__ = ("_states", "_reservations", "_bus", "_loop") + __slots__ = ("_states", "_domain_index", "_reservations", "_bus", "_loop") def __init__(self, bus: EventBus, loop: asyncio.events.AbstractEventLoop) -> None: """Initialize state machine.""" self._states: dict[str, State] = {} + self._domain_index: dict[str, dict[str, State]] = {} self._reservations: set[str] = set() self._bus = bus self._loop = loop @@ -1440,13 +1441,13 @@ class StateMachine: return list(self._states) if isinstance(domain_filter, str): - domain_filter = (domain_filter.lower(),) + return list(self._domain_index.get(domain_filter.lower(), ())) - return [ - state.entity_id - for state in self._states.values() - if state.domain in domain_filter - ] + states: list[str] = [] + for domain in domain_filter: + if domain_index := self._domain_index.get(domain): + states.extend(domain_index) + return states @callback def async_entity_ids_count( @@ -1460,11 +1461,9 @@ class StateMachine: return len(self._states) if isinstance(domain_filter, str): - domain_filter = (domain_filter.lower(),) + return len(self._domain_index.get(domain_filter.lower(), ())) - return len( - [None for state in self._states.values() if state.domain in domain_filter] - ) + return sum(len(self._domain_index.get(domain, ())) for domain in domain_filter) def all(self, domain_filter: str | Iterable[str] | None = None) -> list[State]: """Create a list of all states.""" @@ -1484,11 +1483,13 @@ class StateMachine: return list(self._states.values()) if isinstance(domain_filter, str): - domain_filter = (domain_filter.lower(),) + return list(self._domain_index.get(domain_filter.lower(), {}).values()) - return [ - state for state in self._states.values() if state.domain in domain_filter - ] + states: list[State] = [] + for domain in domain_filter: + if domain_index := self._domain_index.get(domain): + states.extend(domain_index.values()) + return states def get(self, entity_id: str) -> State | None: """Retrieve state of entity_id or None if not found. @@ -1524,13 +1525,12 @@ class StateMachine: """ entity_id = entity_id.lower() old_state = self._states.pop(entity_id, None) - - if entity_id in self._reservations: - self._reservations.remove(entity_id) + self._reservations.discard(entity_id) if old_state is None: return False + self._domain_index[old_state.domain].pop(entity_id) old_state.expire() self._bus.async_fire( EVENT_STATE_CHANGED, @@ -1652,6 +1652,10 @@ class StateMachine: if old_state is not None: old_state.expire() self._states[entity_id] = state + if not (domain_index := self._domain_index.get(state.domain)): + domain_index = {} + self._domain_index[state.domain] = domain_index + domain_index[entity_id] = state self._bus.async_fire( EVENT_STATE_CHANGED, {"entity_id": entity_id, "old_state": old_state, "new_state": state}, diff --git a/tests/test_core.py b/tests/test_core.py index 4f7916e757b..f4a80468050 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1938,6 +1938,7 @@ async def test_async_entity_ids_count(hass: HomeAssistant) -> None: assert hass.states.async_entity_ids_count() == 5 assert hass.states.async_entity_ids_count("light") == 3 + assert hass.states.async_entity_ids_count({"light", "vacuum"}) == 4 async def test_hassjob_forbid_coroutine() -> None: