This commit is contained in:
J. Nick Koston
2025-06-24 22:57:10 +02:00
parent 48f2911434
commit 5ad1af69e4
24 changed files with 269 additions and 167 deletions

View File

@@ -14,8 +14,8 @@ from esphome.const import (
CONF_WEB_SERVER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@grahambrown11", "@hwstar"]
IS_PLATFORM_COMPONENT = True
@@ -149,6 +149,10 @@ _ALARM_CONTROL_PANEL_SCHEMA = (
)
# Add duplicate entity validation
_ALARM_CONTROL_PANEL_SCHEMA.add_extra(entity_duplicate_validator("alarm_control_panel"))
def alarm_control_panel_schema(
class_: MockObjClass,
*,

View File

@@ -60,8 +60,8 @@ from esphome.const import (
DEVICE_CLASS_WINDOW,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
from esphome.util import Registry
CODEOWNERS = ["@esphome/core"]
@@ -491,6 +491,10 @@ _BINARY_SENSOR_SCHEMA = (
)
# Add duplicate entity validation
_BINARY_SENSOR_SCHEMA.add_extra(entity_duplicate_validator("binary_sensor"))
def binary_sensor_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -18,8 +18,8 @@ from esphome.const import (
DEVICE_CLASS_UPDATE,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@esphome/core"]
IS_PLATFORM_COMPONENT = True
@@ -61,6 +61,10 @@ _BUTTON_SCHEMA = (
)
# Add duplicate entity validation
_BUTTON_SCHEMA.add_extra(entity_duplicate_validator("button"))
def button_schema(
class_: MockObjClass,
*,

View File

@@ -48,8 +48,8 @@ from esphome.const import (
CONF_WEB_SERVER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
IS_PLATFORM_COMPONENT = True
@@ -247,6 +247,10 @@ _CLIMATE_SCHEMA = (
)
# Add duplicate entity validation
_CLIMATE_SCHEMA.add_extra(entity_duplicate_validator("climate"))
def climate_schema(
class_: MockObjClass,
*,

View File

@@ -33,8 +33,8 @@ from esphome.const import (
DEVICE_CLASS_WINDOW,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
IS_PLATFORM_COMPONENT = True
@@ -126,6 +126,10 @@ _COVER_SCHEMA = (
)
# Add duplicate entity validation
_COVER_SCHEMA.add_extra(entity_duplicate_validator("cover"))
def cover_schema(
class_: MockObjClass,
*,

View File

@@ -22,8 +22,8 @@ from esphome.const import (
CONF_YEAR,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@rfdarter", "@jesserockz"]
@@ -84,6 +84,9 @@ _DATETIME_SCHEMA = cv.ENTITY_BASE_SCHEMA.extend(
.extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA)
).add_extra(_validate_time_present)
# Add duplicate entity validation
_DATETIME_SCHEMA.add_extra(entity_duplicate_validator("datetime"))
def date_schema(class_: MockObjClass) -> cv.Schema:
schema = cv.Schema(

View File

@@ -19,7 +19,7 @@ from esphome.const import (
CONF_VSYNC_PIN,
)
from esphome.core import CORE
from esphome.cpp_helpers import setup_entity
from esphome.core.entity_helpers import setup_entity
DEPENDENCIES = ["esp32"]

View File

@@ -18,8 +18,8 @@ from esphome.const import (
DEVICE_CLASS_MOTION,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@nohat"]
IS_PLATFORM_COMPONENT = True
@@ -59,6 +59,10 @@ _EVENT_SCHEMA = (
)
# Add duplicate entity validation
_EVENT_SCHEMA.add_extra(entity_duplicate_validator("event"))
def event_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -32,7 +32,7 @@ from esphome.const import (
CONF_WEB_SERVER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.cpp_helpers import setup_entity
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
IS_PLATFORM_COMPONENT = True
@@ -161,6 +161,10 @@ _FAN_SCHEMA = (
)
# Add duplicate entity validation
_FAN_SCHEMA.add_extra(entity_duplicate_validator("fan"))
def fan_schema(
class_: cg.Pvariable,
*,

View File

@@ -38,8 +38,8 @@ from esphome.const import (
CONF_WHITE,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
from .automation import LIGHT_STATE_SCHEMA
from .effects import (
@@ -110,6 +110,9 @@ LIGHT_SCHEMA = (
)
)
# Add duplicate entity validation
LIGHT_SCHEMA.add_extra(entity_duplicate_validator("light"))
BINARY_LIGHT_SCHEMA = LIGHT_SCHEMA.extend(
{
cv.Optional(CONF_EFFECTS): validate_effects(BINARY_EFFECTS),

View File

@@ -14,8 +14,8 @@ from esphome.const import (
CONF_WEB_SERVER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@esphome/core"]
IS_PLATFORM_COMPONENT = True
@@ -67,6 +67,10 @@ _LOCK_SCHEMA = (
)
# Add duplicate entity validation
_LOCK_SCHEMA.add_extra(entity_duplicate_validator("lock"))
def lock_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -11,9 +11,9 @@ from esphome.const import (
CONF_VOLUME,
)
from esphome.core import CORE
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.coroutine import coroutine_with_priority
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@jesserockz"]
@@ -143,6 +143,9 @@ _MEDIA_PLAYER_SCHEMA = cv.ENTITY_BASE_SCHEMA.extend(
}
)
# Add duplicate entity validation
_MEDIA_PLAYER_SCHEMA.add_extra(entity_duplicate_validator("media_player"))
def media_player_schema(
class_: MockObjClass,
@@ -166,7 +169,6 @@ def media_player_schema(
MEDIA_PLAYER_SCHEMA = media_player_schema(MediaPlayer)
MEDIA_PLAYER_SCHEMA.add_extra(cv.deprecated_schema_constant("media_player"))
MEDIA_PLAYER_ACTION_SCHEMA = automation.maybe_simple_id(
cv.Schema(
{

View File

@@ -76,8 +76,8 @@ from esphome.const import (
DEVICE_CLASS_WIND_SPEED,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@esphome/core"]
DEVICE_CLASSES = [
@@ -207,6 +207,10 @@ _NUMBER_SCHEMA = (
)
# Add duplicate entity validation
_NUMBER_SCHEMA.add_extra(entity_duplicate_validator("number"))
def number_schema(
class_: MockObjClass,
*,

View File

@@ -17,8 +17,8 @@ from esphome.const import (
CONF_WEB_SERVER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@esphome/core"]
IS_PLATFORM_COMPONENT = True
@@ -65,6 +65,10 @@ _SELECT_SCHEMA = (
)
# Add duplicate entity validation
_SELECT_SCHEMA.add_extra(entity_duplicate_validator("select"))
def select_schema(
class_: MockObjClass,
*,

View File

@@ -101,8 +101,8 @@ from esphome.const import (
ENTITY_CATEGORY_CONFIG,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
from esphome.util import Registry
CODEOWNERS = ["@esphome/core"]
@@ -318,6 +318,9 @@ _SENSOR_SCHEMA = (
)
)
# Add duplicate entity validation
_SENSOR_SCHEMA.add_extra(entity_duplicate_validator("sensor"))
def sensor_schema(
class_: MockObjClass = cv.UNDEFINED,

View File

@@ -20,8 +20,8 @@ from esphome.const import (
DEVICE_CLASS_SWITCH,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@esphome/core"]
IS_PLATFORM_COMPONENT = True
@@ -91,6 +91,10 @@ _SWITCH_SCHEMA = (
)
# Add duplicate entity validation
_SWITCH_SCHEMA.add_extra(entity_duplicate_validator("switch"))
def switch_schema(
class_: MockObjClass,
*,

View File

@@ -14,8 +14,8 @@ from esphome.const import (
CONF_WEB_SERVER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@mauritskorse"]
IS_PLATFORM_COMPONENT = True
@@ -58,6 +58,10 @@ _TEXT_SCHEMA = (
)
# Add duplicate entity validation
_TEXT_SCHEMA.add_extra(entity_duplicate_validator("text"))
def text_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -21,8 +21,8 @@ from esphome.const import (
DEVICE_CLASS_TIMESTAMP,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
from esphome.util import Registry
DEVICE_CLASSES = [
@@ -153,6 +153,10 @@ _TEXT_SENSOR_SCHEMA = (
)
# Add duplicate entity validation
_TEXT_SENSOR_SCHEMA.add_extra(entity_duplicate_validator("text_sensor"))
def text_sensor_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -15,8 +15,8 @@ from esphome.const import (
ENTITY_CATEGORY_CONFIG,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
CODEOWNERS = ["@jesserockz"]
IS_PLATFORM_COMPONENT = True
@@ -58,6 +58,10 @@ _UPDATE_SCHEMA = (
)
# Add duplicate entity validation
_UPDATE_SCHEMA.add_extra(entity_duplicate_validator("update"))
def update_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -22,8 +22,8 @@ from esphome.const import (
DEVICE_CLASS_WATER,
)
from esphome.core import CORE, coroutine_with_priority
from esphome.core.entity_helpers import entity_duplicate_validator, setup_entity
from esphome.cpp_generator import MockObjClass
from esphome.cpp_helpers import setup_entity
IS_PLATFORM_COMPONENT = True
@@ -103,6 +103,10 @@ _VALVE_SCHEMA = (
)
# Add duplicate entity validation
_VALVE_SCHEMA.add_extra(entity_duplicate_validator("valve"))
def valve_schema(
class_: MockObjClass = cv.UNDEFINED,
*,

View File

@@ -1,5 +1,115 @@
from esphome.const import CONF_ID
from collections.abc import Callable
import logging
from esphome.const import (
CONF_DEVICE_ID,
CONF_DISABLED_BY_DEFAULT,
CONF_ENTITY_CATEGORY,
CONF_ICON,
CONF_ID,
CONF_INTERNAL,
CONF_NAME,
)
from esphome.core import CORE, ID
from esphome.cpp_generator import MockObj, add, get_variable
import esphome.final_validate as fv
from esphome.helpers import sanitize, snake_case
from esphome.types import ConfigType
_LOGGER = logging.getLogger(__name__)
def get_base_entity_object_id(
name: str, friendly_name: str | None, device_name: str | None = None
) -> str:
"""Calculate the base object ID for an entity that will be set via set_object_id().
This function calculates what object_id_c_str_ should be set to in C++.
The C++ EntityBase::get_object_id() (entity_base.cpp lines 38-49) works as:
- If !has_own_name && is_name_add_mac_suffix_enabled():
return str_sanitize(str_snake_case(App.get_friendly_name())) // Dynamic
- Else:
return object_id_c_str_ ?? "" // What we set via set_object_id()
Since we're calculating what to pass to set_object_id(), we always need to
generate the object_id the same way, regardless of name_add_mac_suffix setting.
Args:
name: The entity name (empty string if no name)
friendly_name: The friendly name from CORE.friendly_name
device_name: The device name if entity is on a sub-device
Returns:
The base object ID to use for duplicate checking and to pass to set_object_id()
"""
if name:
# Entity has its own name (has_own_name will be true)
base_str = name
elif device_name:
# Entity has empty name and is on a sub-device
# C++ EntityBase::set_name() uses device->get_name() when device is set
base_str = device_name
elif friendly_name:
# Entity has empty name (has_own_name will be false)
# C++ uses App.get_friendly_name() which returns friendly_name or device name
base_str = friendly_name
else:
# Fallback to device name
base_str = CORE.name
return sanitize(snake_case(base_str))
async def setup_entity(var: MockObj, config: ConfigType, platform: str) -> None:
"""Set up generic properties of an Entity.
This function sets up the common entity properties like name, icon,
entity category, etc.
Args:
var: The entity variable to set up
config: Configuration dictionary containing entity settings
platform: The platform name (e.g., "sensor", "binary_sensor")
"""
# Get device info
device_name: str | None = None
if CONF_DEVICE_ID in config:
device_id_obj: ID = config[CONF_DEVICE_ID]
device: MockObj = await get_variable(device_id_obj)
add(var.set_device(device))
# Get device name for object ID calculation
device_name = device_id_obj.id
add(var.set_name(config[CONF_NAME]))
# Calculate base object_id using the same logic as C++
# This must match the C++ behavior in esphome/core/entity_base.cpp
base_object_id = get_base_entity_object_id(
config[CONF_NAME], CORE.friendly_name, device_name
)
if not config[CONF_NAME]:
_LOGGER.debug(
"Entity has empty name, using '%s' as object_id base", base_object_id
)
# Set the object ID
add(var.set_object_id(base_object_id))
_LOGGER.debug(
"Setting object_id '%s' for entity '%s' on platform '%s'",
base_object_id,
config[CONF_NAME],
platform,
)
add(var.set_disabled_by_default(config[CONF_DISABLED_BY_DEFAULT]))
if CONF_INTERNAL in config:
add(var.set_internal(config[CONF_INTERNAL]))
if CONF_ICON in config:
add(var.set_icon(config[CONF_ICON]))
if CONF_ENTITY_CATEGORY in config:
add(var.set_entity_category(config[CONF_ENTITY_CATEGORY]))
def inherit_property_from(property_to_inherit, parent_id_property, transform=None):
@@ -54,3 +164,60 @@ def inherit_property_from(property_to_inherit, parent_id_property, transform=Non
return config
return inherit_property
def entity_duplicate_validator(platform: str) -> Callable[[ConfigType], ConfigType]:
"""Create a validator function to check for duplicate entity names.
This validator is meant to be used with schema.add_extra() for entity base schemas.
Args:
platform: The platform name (e.g., "sensor", "binary_sensor")
Returns:
A validator function that checks for duplicate names
"""
def validator(config: ConfigType) -> ConfigType:
if CONF_NAME not in config:
# No name to validate
return config
# Get the entity name and device info
entity_name = config[CONF_NAME]
device_id = 0 # Main device by default
device_name = None
if CONF_DEVICE_ID in config:
device_config = config[CONF_DEVICE_ID]
if hasattr(device_config, "id"):
device_id = hash(device_config.id)
# Try to get device name from CORE if available
for dev in getattr(CORE, "devices", []):
if hasattr(dev, "id") and dev.id == device_config.id:
device_name = getattr(dev, "name", None)
break
# Calculate the base object ID
base_object_id = get_base_entity_object_id(
entity_name, CORE.friendly_name, device_name
)
# Check for duplicates
unique_key = (device_id, platform, base_object_id)
if unique_key in CORE.unique_ids:
# Import here to avoid circular dependency
import esphome.config_validation as cv
entity_name_display = entity_name or base_object_id
device_prefix = f" on device '{device_name}'" if device_name else ""
raise cv.Invalid(
f"Duplicate {platform} entity with name '{entity_name_display}' found{device_prefix}. "
f"Each entity on a device must have a unique name within its platform."
)
# Add to tracking set
CORE.unique_ids.add(unique_key)
return config
return validator

View File

@@ -11,9 +11,6 @@ from esphome.core import CORE, ID, coroutine
from esphome.coroutine import FakeAwaitable
from esphome.cpp_generator import add, get_variable
from esphome.cpp_types import App
from esphome.entity import ( # noqa: F401 # pylint: disable=unused-import
setup_entity, # Import for backward compatibility
)
from esphome.types import ConfigFragmentType, ConfigType
from esphome.util import Registry, RegistryEntry

View File

@@ -1,132 +0,0 @@
"""Entity-related helper functions."""
import logging
from esphome.const import (
CONF_DEVICE_ID,
CONF_DISABLED_BY_DEFAULT,
CONF_ENTITY_CATEGORY,
CONF_ICON,
CONF_INTERNAL,
CONF_NAME,
)
from esphome.core import CORE, ID
from esphome.cpp_generator import MockObj, add, get_variable
from esphome.helpers import fnv1a_32bit_hash, sanitize, snake_case
from esphome.types import ConfigType
_LOGGER = logging.getLogger(__name__)
def get_base_entity_object_id(
name: str, friendly_name: str | None, device_name: str | None = None
) -> str:
"""Calculate the base object ID for an entity that will be set via set_object_id().
This function calculates what object_id_c_str_ should be set to in C++.
The C++ EntityBase::get_object_id() (entity_base.cpp lines 38-49) works as:
- If !has_own_name && is_name_add_mac_suffix_enabled():
return str_sanitize(str_snake_case(App.get_friendly_name())) // Dynamic
- Else:
return object_id_c_str_ ?? "" // What we set via set_object_id()
Since we're calculating what to pass to set_object_id(), we always need to
generate the object_id the same way, regardless of name_add_mac_suffix setting.
Args:
name: The entity name (empty string if no name)
friendly_name: The friendly name from CORE.friendly_name
device_name: The device name if entity is on a sub-device
Returns:
The base object ID to use for duplicate checking and to pass to set_object_id()
"""
if name:
# Entity has its own name (has_own_name will be true)
base_str = name
elif device_name:
# Entity has empty name and is on a sub-device
# C++ EntityBase::set_name() uses device->get_name() when device is set
base_str = device_name
elif friendly_name:
# Entity has empty name (has_own_name will be false)
# C++ uses App.get_friendly_name() which returns friendly_name or device name
base_str = friendly_name
else:
# Fallback to device name
base_str = CORE.name
return sanitize(snake_case(base_str))
async def setup_entity(var: MockObj, config: ConfigType, platform: str) -> None:
"""Set up generic properties of an Entity.
This function handles duplicate entity names by automatically appending
a suffix (_2, _3, etc.) when multiple entities have the same object_id
within the same platform and device combination.
Args:
var: The entity variable to set up
config: Configuration dictionary containing entity settings
platform: The platform name (e.g., "sensor", "binary_sensor")
"""
# Get device info
device_id: int = 0
device_name: str | None = None
if CONF_DEVICE_ID in config:
device_id_obj: ID = config[CONF_DEVICE_ID]
device: MockObj = await get_variable(device_id_obj)
add(var.set_device(device))
# Use the device's ID hash as device_id
device_id = fnv1a_32bit_hash(device_id_obj.id)
# Get device name for object ID calculation
device_name = device_id_obj.id
add(var.set_name(config[CONF_NAME]))
# Calculate base object_id using the same logic as C++
# This must match the C++ behavior in esphome/core/entity_base.cpp
base_object_id = get_base_entity_object_id(
config[CONF_NAME], CORE.friendly_name, device_name
)
if not config[CONF_NAME]:
_LOGGER.debug(
"Entity has empty name, using '%s' as object_id base", base_object_id
)
# Check for duplicates
unique_key: tuple[int, str, str] = (device_id, platform, base_object_id)
if unique_key in CORE.unique_ids:
# Found duplicate - fail validation
from esphome.config_validation import Invalid
entity_name = config[CONF_NAME] or base_object_id
device_prefix = f" on device '{device_name}'" if device_name else ""
raise Invalid(
f"Duplicate {platform} entity with name '{entity_name}' found{device_prefix}. "
f"Each entity on a device must have a unique name within its platform."
)
else:
# First occurrence - register it
CORE.unique_ids.add(unique_key)
object_id = base_object_id
add(var.set_object_id(object_id))
_LOGGER.debug(
"Setting object_id '%s' for entity '%s' on platform '%s'",
object_id,
config[CONF_NAME],
platform,
)
add(var.set_disabled_by_default(config[CONF_DISABLED_BY_DEFAULT]))
if CONF_INTERNAL in config:
add(var.set_internal(config[CONF_INTERNAL]))
if CONF_ICON in config:
add(var.set_icon(config[CONF_ICON]))
if CONF_ENTITY_CATEGORY in config:
add(var.set_entity_category(config[CONF_ENTITY_CATEGORY]))

View File

@@ -6,12 +6,11 @@ from typing import Any
import pytest
from esphome import entity
from esphome.config_validation import Invalid
from esphome.const import CONF_DEVICE_ID, CONF_DISABLED_BY_DEFAULT, CONF_ICON, CONF_NAME
from esphome.core import CORE, ID
from esphome.core import CORE, ID, entity_helpers
from esphome.core.entity_helpers import get_base_entity_object_id, setup_entity
from esphome.cpp_generator import MockObj
from esphome.entity import get_base_entity_object_id, setup_entity
from esphome.helpers import sanitize, snake_case
# Pre-compiled regex pattern for extracting object IDs from expressions
@@ -240,7 +239,7 @@ def setup_test_environment() -> Generator[list[str], None, None]:
CORE.friendly_name = "Test Device"
# Store original add function
original_add = entity.add
original_add = entity_helpers.add
# Track what gets added
added_expressions: list[str] = []
@@ -248,11 +247,11 @@ def setup_test_environment() -> Generator[list[str], None, None]:
added_expressions.append(str(expression))
return original_add(expression)
# Patch add function in entity module
entity.add = mock_add
# Patch add function in entity_helpers module
entity_helpers.add = mock_add
yield added_expressions
# Clean up
entity.add = original_add
entity_helpers.add = original_add
def extract_object_id_from_expressions(expressions: list[str]) -> str | None:
@@ -372,17 +371,17 @@ async def test_setup_entity_different_platforms(
def mock_get_variable() -> Generator[dict[ID, MockObj], None, None]:
"""Mock get_variable to return test devices."""
devices = {}
original_get_variable = entity.get_variable
original_get_variable = entity_helpers.get_variable
async def _mock_get_variable(device_id: ID) -> MockObj:
if device_id in devices:
return devices[device_id]
return await original_get_variable(device_id)
entity.get_variable = _mock_get_variable
entity_helpers.get_variable = _mock_get_variable
yield devices
# Clean up
entity.get_variable = original_get_variable
entity_helpers.get_variable = original_get_variable
@pytest.mark.asyncio