Refactor EntityPlatform (#147927)

This commit is contained in:
Erik Montnemery 2025-07-22 14:35:57 +02:00 committed by GitHub
parent 5a771b501d
commit dd399ef59f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 256 additions and 118 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Mapping
import contextlib
from datetime import datetime, timedelta
from datetime import datetime
from errno import EHOSTUNREACH, EIO
import io
import logging
@ -52,9 +52,8 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import config_validation as cv, template as template_helper
from homeassistant.helpers.entity_platform import EntityPlatform
from homeassistant.helpers.entity_platform import PlatformData
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util import slugify
from .camera import GenericCamera, generate_auth
@ -569,18 +568,9 @@ async def ws_start_preview(
)
user_input = flow.preview_image_settings
# Create an EntityPlatform, needed for name translations
platform = await async_prepare_setup_platform(hass, {}, CAMERA_DOMAIN, DOMAIN)
entity_platform = EntityPlatform(
hass=hass,
logger=_LOGGER,
domain=CAMERA_DOMAIN,
platform_name=DOMAIN,
platform=platform,
scan_interval=timedelta(seconds=3600),
entity_namespace=None,
)
await entity_platform.async_load_translations()
# Create PlatformData, needed for name translations
platform_data = PlatformData(hass=hass, domain=CAMERA_DOMAIN, platform_name=DOMAIN)
await platform_data.async_load_translations()
ha_still_url = None
ha_stream_url = None

View File

@ -387,7 +387,9 @@ class NumberEntity(Entity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
if (translation_key := self._unit_of_measurement_translation_key) and (
unit_of_measurement
:= self.platform.default_language_platform_translations.get(translation_key)
:= self.platform_data.default_language_platform_translations.get(
translation_key
)
):
if native_unit_of_measurement is not None:
raise ValueError(

View File

@ -523,7 +523,9 @@ class SensorEntity(Entity, cached_properties=CACHED_PROPERTIES_WITH_ATTR_):
# Fourth priority: Unit translation
if (translation_key := self._unit_of_measurement_translation_key) and (
unit_of_measurement
:= self.platform.default_language_platform_translations.get(translation_key)
:= self.platform_data.default_language_platform_translations.get(
translation_key
)
):
if native_unit_of_measurement is not None:
raise ValueError(

View File

@ -3,7 +3,6 @@
from __future__ import annotations
from collections.abc import Mapping
from datetime import timedelta
import logging
from typing import Any
@ -12,7 +11,7 @@ import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import EntityPlatform
from homeassistant.helpers.entity_platform import PlatformData
from homeassistant.helpers.schema_config_entry_flow import (
SchemaCommonFlowHandler,
SchemaConfigFlowHandler,
@ -24,7 +23,6 @@ from homeassistant.helpers.selector import (
SelectSelectorConfig,
SelectSelectorMode,
)
from homeassistant.setup import async_prepare_setup_platform
from .const import CONF_DISPLAY_OPTIONS, DOMAIN, OPTION_TYPES
from .sensor import TimeDateSensor
@ -99,18 +97,9 @@ async def ws_start_preview(
"""Generate a preview."""
validated = USER_SCHEMA(msg["user_input"])
# Create an EntityPlatform, needed for name translations
platform = await async_prepare_setup_platform(hass, {}, SENSOR_DOMAIN, DOMAIN)
entity_platform = EntityPlatform(
hass=hass,
logger=_LOGGER,
domain=SENSOR_DOMAIN,
platform_name=DOMAIN,
platform=platform,
scan_interval=timedelta(seconds=3600),
entity_namespace=None,
)
await entity_platform.async_load_translations()
# Create PlatformData, needed for name translations
platform_data = PlatformData(hass=hass, domain=SENSOR_DOMAIN, platform_name=DOMAIN)
await platform_data.async_load_translations()
@callback
def async_preview_updated(state: str, attributes: Mapping[str, Any]) -> None:
@ -123,7 +112,7 @@ async def ws_start_preview(
preview_entity = TimeDateSensor(validated[CONF_DISPLAY_OPTIONS])
preview_entity.hass = hass
preview_entity.platform = entity_platform
preview_entity.platform_data = platform_data
connection.send_result(msg["id"])
connection.subscriptions[msg["id"]] = preview_entity.async_start_preview(

View File

@ -66,7 +66,7 @@ from .typing import UNDEFINED, StateType, UndefinedType
timer = time.time
if TYPE_CHECKING:
from .entity_platform import EntityPlatform
from .entity_platform import EntityPlatform, PlatformData
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10
@ -449,6 +449,7 @@ class Entity(
# While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts.
platform: EntityPlatform = None # type: ignore[assignment]
platform_data: PlatformData = None # type: ignore[assignment]
# Entity description instance for this Entity
entity_description: EntityDescription
@ -593,7 +594,7 @@ class Entity(
return not self._attr_name
if (
name_translation_key := self._name_translation_key
) and name_translation_key in self.platform.platform_translations:
) and name_translation_key in self.platform_data.platform_translations:
return False
if hasattr(self, "entity_description"):
return not self.entity_description.name
@ -616,9 +617,9 @@ class Entity(
if not self.has_entity_name:
return None
device_class_key = self.device_class or "_"
platform = self.platform
platform_domain = self.platform_data.domain
name_translation_key = (
f"component.{platform.domain}.entity_component.{device_class_key}.name"
f"component.{platform_domain}.entity_component.{device_class_key}.name"
)
return component_translations.get(name_translation_key)
@ -626,13 +627,13 @@ class Entity(
def _object_id_device_class_name(self) -> str | None:
"""Return a translated name of the entity based on its device class."""
return self._device_class_name_helper(
self.platform.object_id_component_translations
self.platform_data.object_id_component_translations
)
@cached_property
def _device_class_name(self) -> str | None:
"""Return a translated name of the entity based on its device class."""
return self._device_class_name_helper(self.platform.component_translations)
return self._device_class_name_helper(self.platform_data.component_translations)
def _default_to_device_class_name(self) -> bool:
"""Return True if an unnamed entity should be named by its device class."""
@ -643,9 +644,9 @@ class Entity(
"""Return translation key for entity name."""
if self.translation_key is None:
return None
platform = self.platform
platform_data = self.platform_data
return (
f"component.{platform.platform_name}.entity.{platform.domain}"
f"component.{platform_data.platform_name}.entity.{platform_data.domain}"
f".{self.translation_key}.name"
)
@ -654,14 +655,14 @@ class Entity(
"""Return translation key for unit of measurement."""
if self.translation_key is None:
return None
if self.platform is None:
if self.platform_data is None:
raise ValueError(
f"Entity {type(self)} cannot have a translation key for "
"unit of measurement before being added to the entity platform"
)
platform = self.platform
platform_data = self.platform_data
return (
f"component.{platform.platform_name}.entity.{platform.domain}"
f"component.{platform_data.platform_name}.entity.{platform_data.domain}"
f".{self.translation_key}.unit_of_measurement"
)
@ -724,13 +725,13 @@ class Entity(
# value.
type.__getattribute__(self.__class__, "name")
is type.__getattribute__(Entity, "name")
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
and self.platform
# The check for self.platform_data guards against integrations not using an
# EntityComponent and can be removed in HA Core 2026.8
and self.platform_data
):
name = self._name_internal(
self._object_id_device_class_name,
self.platform.object_id_platform_translations,
self.platform_data.object_id_platform_translations,
)
else:
name = self.name
@ -739,13 +740,13 @@ class Entity(
@cached_property
def name(self) -> str | UndefinedType | None:
"""Return the name of the entity."""
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
if not self.platform:
# The check for self.platform_data guards against integrations not using an
# EntityComponent and can be removed in HA Core 2026.8
if not self.platform_data:
return self._name_internal(None, {})
return self._name_internal(
self._device_class_name,
self.platform.platform_translations,
self.platform_data.platform_translations,
)
@cached_property
@ -986,7 +987,7 @@ class Entity(
raise RuntimeError(f"Attribute hass is None for {self}")
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
# EntityComponent and can be removed in HA Core 2026.8
if self.platform is None and not self._no_platform_reported: # type: ignore[unreachable]
report_issue = self._suggest_report_issue() # type: ignore[unreachable]
_LOGGER.warning(
@ -1351,6 +1352,7 @@ class Entity(
self.hass = hass
self.platform = platform
self.platform_data = platform.platform_data
self.parallel_updates = parallel_updates
self._platform_state = EntityPlatformState.ADDING
@ -1494,7 +1496,7 @@ class Entity(
Not to be extended by integrations.
"""
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
# EntityComponent and can be removed in HA Core 2026.8
if self.platform:
del entity_sources(self.hass)[self.entity_id]
@ -1626,9 +1628,9 @@ class Entity(
def _suggest_report_issue(self) -> str:
"""Suggest to report an issue."""
# The check for self.platform guards against integrations not using an
# EntityComponent and can be removed in HA Core 2024.1
platform_name = self.platform.platform_name if self.platform else None
# The check for self.platform_data guards against integrations not using an
# EntityComponent and can be removed in HA Core 2026.8
platform_name = self.platform_data.platform_name if self.platform_data else None
return async_suggest_report_issue(
self.hass, integration_domain=platform_name, module=type(self).__module__
)

View File

@ -44,6 +44,7 @@ from . import (
service,
translation,
)
from .deprecation import deprecated_function
from .entity_registry import EntityRegistry, RegistryEntryDisabler, RegistryEntryHider
from .event import async_call_later
from .issue_registry import IssueSeverity, async_create_issue
@ -126,6 +127,77 @@ class EntityPlatformModule(Protocol):
"""Set up an integration platform from a config entry."""
class PlatformData:
"""Information about a platform, used by entities."""
def __init__(
self,
hass: HomeAssistant,
*,
domain: str,
platform_name: str,
) -> None:
"""Initialize the base entity platform."""
self.hass = hass
self.domain = domain
self.platform_name = platform_name
self.component_translations: dict[str, str] = {}
self.platform_translations: dict[str, str] = {}
self.object_id_component_translations: dict[str, str] = {}
self.object_id_platform_translations: dict[str, str] = {}
self.default_language_platform_translations: dict[str, str] = {}
async def _async_get_translations(
self, language: str, category: str, integration: str
) -> dict[str, str]:
"""Get translations for a language, category, and integration."""
try:
return await translation.async_get_translations(
self.hass, language, category, {integration}
)
except Exception as err: # noqa: BLE001
_LOGGER.debug(
"Could not load translations for %s",
integration,
exc_info=err,
)
return {}
async def async_load_translations(self) -> None:
"""Load translations."""
hass = self.hass
object_id_language = (
hass.config.language
if hass.config.language in languages.NATIVE_ENTITY_IDS
else languages.DEFAULT_LANGUAGE
)
config_language = hass.config.language
self.component_translations = await self._async_get_translations(
config_language, "entity_component", self.domain
)
self.platform_translations = await self._async_get_translations(
config_language, "entity", self.platform_name
)
if object_id_language == config_language:
self.object_id_component_translations = self.component_translations
self.object_id_platform_translations = self.platform_translations
else:
self.object_id_component_translations = await self._async_get_translations(
object_id_language, "entity_component", self.domain
)
self.object_id_platform_translations = await self._async_get_translations(
object_id_language, "entity", self.platform_name
)
if config_language == languages.DEFAULT_LANGUAGE:
self.default_language_platform_translations = self.platform_translations
else:
self.default_language_platform_translations = (
await self._async_get_translations(
languages.DEFAULT_LANGUAGE, "entity", self.platform_name
)
)
class EntityPlatform:
"""Manage the entities for a single platform.
@ -147,8 +219,6 @@ class EntityPlatform:
"""Initialize the entity platform."""
self.hass = hass
self.logger = logger
self.domain = domain
self.platform_name = platform_name
self.platform = platform
self.scan_interval = scan_interval
self.scan_interval_seconds = scan_interval.total_seconds()
@ -157,11 +227,6 @@ class EntityPlatform:
# Storage for entities for this specific platform only
# which are indexed by entity_id
self.entities: dict[str, Entity] = {}
self.component_translations: dict[str, str] = {}
self.platform_translations: dict[str, str] = {}
self.object_id_component_translations: dict[str, str] = {}
self.object_id_platform_translations: dict[str, str] = {}
self.default_language_platform_translations: dict[str, str] = {}
self._tasks: list[asyncio.Task[None]] = []
# Stop tracking tasks after setup is completed
self._setup_complete = False
@ -195,6 +260,10 @@ class EntityPlatform:
DATA_DOMAIN_PLATFORM_ENTITIES, {}
).setdefault(key, {})
self.platform_data = PlatformData(
hass, domain=domain, platform_name=platform_name
)
def __repr__(self) -> str:
"""Represent an EntityPlatform."""
return (
@ -362,7 +431,7 @@ class EntityPlatform:
hass = self.hass
full_name = f"{self.platform_name}.{self.domain}"
await self.async_load_translations()
await self.platform_data.async_load_translations()
logger.info("Setting up %s", full_name)
warn_task = hass.loop.call_at(
@ -457,56 +526,6 @@ class EntityPlatform:
finally:
warn_task.cancel()
async def _async_get_translations(
self, language: str, category: str, integration: str
) -> dict[str, str]:
"""Get translations for a language, category, and integration."""
try:
return await translation.async_get_translations(
self.hass, language, category, {integration}
)
except Exception as err: # noqa: BLE001
_LOGGER.debug(
"Could not load translations for %s",
integration,
exc_info=err,
)
return {}
async def async_load_translations(self) -> None:
"""Load translations."""
hass = self.hass
object_id_language = (
hass.config.language
if hass.config.language in languages.NATIVE_ENTITY_IDS
else languages.DEFAULT_LANGUAGE
)
config_language = hass.config.language
self.component_translations = await self._async_get_translations(
config_language, "entity_component", self.domain
)
self.platform_translations = await self._async_get_translations(
config_language, "entity", self.platform_name
)
if object_id_language == config_language:
self.object_id_component_translations = self.component_translations
self.object_id_platform_translations = self.platform_translations
else:
self.object_id_component_translations = await self._async_get_translations(
object_id_language, "entity_component", self.domain
)
self.object_id_platform_translations = await self._async_get_translations(
object_id_language, "entity", self.platform_name
)
if config_language == languages.DEFAULT_LANGUAGE:
self.default_language_platform_translations = self.platform_translations
else:
self.default_language_platform_translations = (
await self._async_get_translations(
languages.DEFAULT_LANGUAGE, "entity", self.platform_name
)
)
def _schedule_add_entities(
self, new_entities: Iterable[Entity], update_before_add: bool = False
) -> None:
@ -1120,6 +1139,87 @@ class EntityPlatform:
]:
await asyncio.gather(*tasks)
@property
def domain(self) -> str:
"""Return the domain (e.g. light)."""
return self.platform_data.domain
@property
def platform_name(self) -> str:
"""Return the platform name (e.g hue)."""
return self.platform_data.platform_name
@property
@deprecated_function(
"platform_data.component_translations",
breaks_in_ha_version="2026.8",
)
def component_translations(self) -> dict[str, str]:
"""Return the component translations.
Will be removed in Home Assistant Core 2026.8.
"""
return self.platform_data.component_translations
@property
@deprecated_function(
"platform_data.platform_translations",
breaks_in_ha_version="2026.8",
)
def platform_translations(self) -> dict[str, str]:
"""Return the platform translations.
Will be removed in Home Assistant Core 2026.8.
"""
return self.platform_data.platform_translations
@property
@deprecated_function(
"platform_data.object_id_component_translations",
breaks_in_ha_version="2026.8",
)
def object_id_component_translations(self) -> dict[str, str]:
"""Return the object ID component translations.
Will be removed in Home Assistant Core 2026.8.
"""
return self.platform_data.object_id_component_translations
@property
@deprecated_function(
"platform_data.object_id_platform_translations",
breaks_in_ha_version="2026.8",
)
def object_id_platform_translations(self) -> dict[str, str]:
"""Return the object ID platform translations.
Will be removed in Home Assistant Core 2026.8.
"""
return self.platform_data.object_id_platform_translations
@property
@deprecated_function(
"platform_data.default_language_platform_translations",
breaks_in_ha_version="2026.8",
)
def default_language_platform_translations(self) -> dict[str, str]:
"""Return the default language platform translations.
Will be removed in Home Assistant Core 2026.8.
"""
return self.platform_data.default_language_platform_translations
@deprecated_function(
"platform_data.async_load_translations",
breaks_in_ha_version="2026.8",
)
async def async_load_translations(self) -> None:
"""Load translations.
Will be removed in Home Assistant Core 2026.8.
"""
return await self.platform_data.async_load_translations()
@callback
def async_calculate_suggested_object_id(

View File

@ -685,7 +685,7 @@ async def test_generic_workaround(
rest_client.get_jpeg_snapshot.return_value = image_bytes
camera.set_stream_source("https://my_stream_url.m3u8")
with patch.object(camera.platform, "platform_name", "generic"):
with patch.object(camera.platform.platform_data, "platform_name", "generic"):
image = await async_get_image(hass, camera.entity_id)
assert image.content == image_bytes

View File

@ -781,7 +781,7 @@ async def test_warn_slow_write_state(
mock_entity = entity.Entity()
mock_entity.hass = hass
mock_entity.entity_id = "comp_test.test_entity"
mock_entity.platform = MagicMock(platform_name="hue")
mock_entity.platform_data = MagicMock(platform_name="hue")
mock_entity._platform_state = entity.EntityPlatformState.ADDED
with patch("homeassistant.helpers.entity.timer", side_effect=[0, 10]):
@ -809,7 +809,7 @@ async def test_warn_slow_write_state_custom_component(
mock_entity = CustomComponentEntity()
mock_entity.hass = hass
mock_entity.entity_id = "comp_test.test_entity"
mock_entity.platform = MagicMock(platform_name="hue")
mock_entity.platform_data = MagicMock(platform_name="hue")
mock_entity._platform_state = entity.EntityPlatformState.ADDED
with patch("homeassistant.helpers.entity.timer", side_effect=[0, 10]):

View File

@ -2447,3 +2447,56 @@ async def test_add_entity_unknown_subentry(
"Can't add entities to unknown subentry unknown-subentry "
"of config entry super-mock-id"
) in caplog.text
@pytest.mark.parametrize("integration_frame_path", ["custom_components/my_integration"])
@pytest.mark.usefixtures("mock_integration_frame")
@pytest.mark.parametrize(
"deprecated_attribute",
[
"component_translations",
"platform_translations",
"object_id_component_translations",
"object_id_platform_translations",
"default_language_platform_translations",
],
)
async def test_deprecated_attributes(
hass: HomeAssistant,
deprecated_attribute: str,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test setting the device name based on input info."""
platform = MockPlatform()
entity_platform = MockEntityPlatform(hass, platform_name="test", platform=platform)
assert getattr(entity_platform, deprecated_attribute) is getattr(
entity_platform.platform_data, deprecated_attribute
)
assert (
f"The deprecated function {deprecated_attribute} was called from "
"my_integration. It will be removed in HA Core 2026.8. Use platform_data."
f"{deprecated_attribute} instead, please report it to the author of the "
"'my_integration' custom integration" in caplog.text
)
@pytest.mark.parametrize("integration_frame_path", ["custom_components/my_integration"])
@pytest.mark.usefixtures("mock_integration_frame")
async def test_deprecated_async_load_translations(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test setting the device name based on input info."""
platform = MockPlatform()
entity_platform = MockEntityPlatform(hass, platform_name="test", platform=platform)
await entity_platform.async_load_translations()
assert (
"The deprecated function async_load_translations was called from "
"my_integration. It will be removed in HA Core 2026.8. Use platform_data."
"async_load_translations instead, please report it to the author of the "
"'my_integration' custom integration" in caplog.text
)