Improve typing of entity.entity_sources (#99407)

* Improve typing of entity.entity_sources

* Calculate entity info source when generating WS response

* Adjust typing

* Update tests
This commit is contained in:
Erik Montnemery 2023-09-12 20:41:26 +02:00 committed by GitHub
parent cc252f705f
commit 51576b7214
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 39 additions and 26 deletions

View File

@ -707,7 +707,8 @@ class MediaPlayerCapabilities(AlexaEntity):
# AlexaEqualizerController is disabled for denonavr # AlexaEqualizerController is disabled for denonavr
# since it blocks alexa from discovering any devices. # since it blocks alexa from discovering any devices.
domain = entity_sources(self.hass).get(self.entity_id, {}).get("domain") entity_info = entity_sources(self.hass).get(self.entity_id)
domain = entity_info["domain"] if entity_info else None
if ( if (
supported & media_player.MediaPlayerEntityFeature.SELECT_SOUND_MODE supported & media_player.MediaPlayerEntityFeature.SELECT_SOUND_MODE
and domain != "denonavr" and domain != "denonavr"

View File

@ -40,6 +40,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE, MAX_LENGTH_STATE_STATE,
) )
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers.entity import EntityInfo
from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.util.json import ( from homeassistant.util.json import (
@ -558,7 +559,7 @@ class StateAttributes(Base):
@staticmethod @staticmethod
def shared_attrs_bytes_from_event( def shared_attrs_bytes_from_event(
event: Event, event: Event,
entity_sources: dict[str, dict[str, str]], entity_sources: dict[str, EntityInfo],
exclude_attrs_by_domain: dict[str, set[str]], exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None, dialect: SupportedDialect | None,
) -> bytes: ) -> bytes:

View File

@ -15,7 +15,10 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.entity import entity_sources as get_entity_sources from homeassistant.helpers.entity import (
EntityInfo,
entity_sources as get_entity_sources,
)
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
DOMAIN = "search" DOMAIN = "search"
@ -97,7 +100,7 @@ class Searcher:
hass: HomeAssistant, hass: HomeAssistant,
device_reg: dr.DeviceRegistry, device_reg: dr.DeviceRegistry,
entity_reg: er.EntityRegistry, entity_reg: er.EntityRegistry,
entity_sources: dict[str, dict[str, str]], entity_sources: dict[str, EntityInfo],
) -> None: ) -> None:
"""Search results.""" """Search results."""
self.hass = hass self.hass = hass

View File

@ -262,8 +262,9 @@ def _normalize_states(
def _suggest_report_issue(hass: HomeAssistant, entity_id: str) -> str: def _suggest_report_issue(hass: HomeAssistant, entity_id: str) -> str:
"""Suggest to report an issue.""" """Suggest to report an issue."""
domain = entity_sources(hass).get(entity_id, {}).get("domain") entity_info = entity_sources(hass).get(entity_id)
custom_component = entity_sources(hass).get(entity_id, {}).get("custom_component") domain = entity_info["domain"] if entity_info else None
custom_component = entity_info["custom_component"] if entity_info else None
report_issue = "" report_issue = ""
if custom_component: if custom_component:
report_issue = "report it to the custom integration author." report_issue = "report it to the custom integration author."
@ -296,7 +297,8 @@ def warn_dip(
hass.data[WARN_DIP] = set() hass.data[WARN_DIP] = set()
if entity_id not in hass.data[WARN_DIP]: if entity_id not in hass.data[WARN_DIP]:
hass.data[WARN_DIP].add(entity_id) hass.data[WARN_DIP].add(entity_id)
domain = entity_sources(hass).get(entity_id, {}).get("domain") entity_info = entity_sources(hass).get(entity_id)
domain = entity_info["domain"] if entity_info else None
if domain in ["energy", "growatt_server", "solaredge"]: if domain in ["energy", "growatt_server", "solaredge"]:
return return
_LOGGER.warning( _LOGGER.warning(
@ -320,7 +322,8 @@ def warn_negative(hass: HomeAssistant, entity_id: str, state: State) -> None:
hass.data[WARN_NEGATIVE] = set() hass.data[WARN_NEGATIVE] = set()
if entity_id not in hass.data[WARN_NEGATIVE]: if entity_id not in hass.data[WARN_NEGATIVE]:
hass.data[WARN_NEGATIVE].add(entity_id) hass.data[WARN_NEGATIVE].add(entity_id)
domain = entity_sources(hass).get(entity_id, {}).get("domain") entity_info = entity_sources(hass).get(entity_id)
domain = entity_info["domain"] if entity_info else None
_LOGGER.warning( _LOGGER.warning(
( (
"Entity %s %shas state class total_increasing, but its state is " "Entity %s %shas state class total_increasing, but its state is "

View File

@ -596,7 +596,7 @@ async def handle_render_template(
def _serialize_entity_sources( def _serialize_entity_sources(
entity_infos: dict[str, dict[str, str]] entity_infos: dict[str, entity.EntityInfo]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Prepare a websocket response from a dict of entity sources.""" """Prepare a websocket response from a dict of entity sources."""
result = {} result = {}

View File

@ -12,7 +12,16 @@ import logging
import math import math
import sys import sys
from timeit import default_timer as timer from timeit import default_timer as timer
from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, final from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
NotRequired,
TypedDict,
TypeVar,
final,
)
import voluptuous as vol import voluptuous as vol
@ -60,8 +69,6 @@ _T = TypeVar("_T")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10 SLOW_UPDATE_WARNING = 10
DATA_ENTITY_SOURCE = "entity_info" DATA_ENTITY_SOURCE = "entity_info"
SOURCE_CONFIG_ENTRY = "config_entry"
SOURCE_PLATFORM_CONFIG = "platform_config"
# Used when converting float states to string: limit precision according to machine # Used when converting float states to string: limit precision according to machine
# epsilon to make the string representation readable # epsilon to make the string representation readable
@ -76,9 +83,9 @@ def async_setup(hass: HomeAssistant) -> None:
@callback @callback
@bind_hass @bind_hass
def entity_sources(hass: HomeAssistant) -> dict[str, dict[str, str]]: def entity_sources(hass: HomeAssistant) -> dict[str, EntityInfo]:
"""Get the entity sources.""" """Get the entity sources."""
_entity_sources: dict[str, dict[str, str]] = hass.data[DATA_ENTITY_SOURCE] _entity_sources: dict[str, EntityInfo] = hass.data[DATA_ENTITY_SOURCE]
return _entity_sources return _entity_sources
@ -181,6 +188,14 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None:
ENTITY_CATEGORIES_SCHEMA: Final = vol.Coerce(EntityCategory) ENTITY_CATEGORIES_SCHEMA: Final = vol.Coerce(EntityCategory)
class EntityInfo(TypedDict):
"""Entity info."""
domain: str
custom_component: bool
config_entry: NotRequired[str]
class EntityPlatformState(Enum): class EntityPlatformState(Enum):
"""The platform state of an entity.""" """The platform state of an entity."""
@ -1061,18 +1076,15 @@ class Entity(ABC):
Not to be extended by integrations. Not to be extended by integrations.
""" """
info = { info: EntityInfo = {
"domain": self.platform.platform_name, "domain": self.platform.platform_name,
"custom_component": "custom_components" in type(self).__module__, "custom_component": "custom_components" in type(self).__module__,
} }
if self.platform.config_entry: if self.platform.config_entry:
info["source"] = SOURCE_CONFIG_ENTRY
info["config_entry"] = self.platform.config_entry.entry_id info["config_entry"] = self.platform.config_entry.entry_id
else:
info["source"] = SOURCE_PLATFORM_CONFIG
self.hass.data[DATA_ENTITY_SOURCE][self.entity_id] = info entity_sources(self.hass)[self.entity_id] = info
if self.registry_entry is not None: if self.registry_entry is not None:
# This is an assert as it should never happen, but helps in tests # This is an assert as it should never happen, but helps in tests

View File

@ -6,7 +6,6 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers import ( from homeassistant.helpers import (
area_registry as ar, area_registry as ar,
device_registry as dr, device_registry as dr,
entity,
entity_registry as er, entity_registry as er,
) )
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@ -22,11 +21,9 @@ def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
MOCK_ENTITY_SOURCES = { MOCK_ENTITY_SOURCES = {
"light.platform_config_source": { "light.platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "wled", "domain": "wled",
}, },
"light.config_entry_source": { "light.config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": "config_entry_id", "config_entry": "config_entry_id",
"domain": "wled", "domain": "wled",
}, },
@ -73,11 +70,9 @@ async def test_search(
entity_sources = { entity_sources = {
"light.wled_platform_config_source": { "light.wled_platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "wled", "domain": "wled",
}, },
"light.wled_config_entry_source": { "light.wled_config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": wled_config_entry.entry_id, "config_entry": wled_config_entry.entry_id,
"domain": "wled", "domain": "wled",
}, },

View File

@ -795,13 +795,11 @@ async def test_setup_source(hass: HomeAssistant) -> None:
"test_domain.platform_config_source": { "test_domain.platform_config_source": {
"custom_component": False, "custom_component": False,
"domain": "test_platform", "domain": "test_platform",
"source": entity.SOURCE_PLATFORM_CONFIG,
}, },
"test_domain.config_entry_source": { "test_domain.config_entry_source": {
"config_entry": platform.config_entry.entry_id, "config_entry": platform.config_entry.entry_id,
"custom_component": False, "custom_component": False,
"domain": "test_platform", "domain": "test_platform",
"source": entity.SOURCE_CONFIG_ENTRY,
}, },
} }