Move state length validation to StateMachine APIs (#143681)

* Move state length validation to StateMachine async_set method

We call validate_state to make sure we do not allow any states
into the state machine that have a length>255 so we do not break
the recorder. Since async_set_internal already requires callers
to pre-validate the state, we can move the check to async_set
instead of at State object creation time to avoid needing to
check it twice in the hot path (entity write state)

* move check in async_set_internal so it only happens on state change

* no need to check if same_state
This commit is contained in:
J. Nick Koston 2025-04-25 15:15:15 -10:00 committed by GitHub
parent 03950f270a
commit 34d17ca458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 22 deletions

View File

@ -72,6 +72,7 @@ from .const import (
MAX_EXPECTED_ENTITY_IDS, MAX_EXPECTED_ENTITY_IDS,
MAX_LENGTH_EVENT_EVENT_TYPE, MAX_LENGTH_EVENT_EVENT_TYPE,
MAX_LENGTH_STATE_STATE, MAX_LENGTH_STATE_STATE,
STATE_UNKNOWN,
__version__, __version__,
) )
from .exceptions import ( from .exceptions import (
@ -1794,18 +1795,13 @@ class State:
) -> None: ) -> None:
"""Initialize a new state.""" """Initialize a new state."""
self._cache: dict[str, Any] = {} self._cache: dict[str, Any] = {}
state = str(state)
if validate_entity_id and not valid_entity_id(entity_id): if validate_entity_id and not valid_entity_id(entity_id):
raise InvalidEntityFormatError( raise InvalidEntityFormatError(
f"Invalid entity id encountered: {entity_id}. " f"Invalid entity id encountered: {entity_id}. "
"Format should be <domain>.<object_id>" "Format should be <domain>.<object_id>"
) )
validate_state(state)
self.entity_id = entity_id self.entity_id = entity_id
self.state = state self.state = state if type(state) is str else str(state)
# State only creates and expects a ReadOnlyDict so # State only creates and expects a ReadOnlyDict so
# there is no need to check for subclassing with # there is no need to check for subclassing with
# isinstance here so we can use the faster type check. # isinstance here so we can use the faster type check.
@ -2270,9 +2266,11 @@ class StateMachine:
This method must be run in the event loop. This method must be run in the event loop.
""" """
state = str(new_state)
validate_state(state)
self.async_set_internal( self.async_set_internal(
entity_id.lower(), entity_id.lower(),
str(new_state), state,
attributes or {}, attributes or {},
force_update, force_update,
context, context,
@ -2298,6 +2296,8 @@ class StateMachine:
breaking changes to this function in the future and it breaking changes to this function in the future and it
should not be used in integrations. should not be used in integrations.
Callers are responsible for ensuring the entity_id is lower case.
This method must be run in the event loop. This method must be run in the event loop.
""" """
# Most cases the key will be in the dict # Most cases the key will be in the dict
@ -2356,6 +2356,16 @@ class StateMachine:
assert old_state is not None assert old_state is not None
attributes = old_state.attributes attributes = old_state.attributes
if not same_state and len(new_state) > MAX_LENGTH_STATE_STATE:
_LOGGER.error(
"State %s for %s is longer than %s, falling back to %s",
new_state,
entity_id,
MAX_LENGTH_STATE_STATE,
STATE_UNKNOWN,
)
new_state = STATE_UNKNOWN
# This is intentionally called with positional only arguments for performance # This is intentionally called with positional only arguments for performance
# reasons # reasons
state = State( state = State(

View File

@ -31,7 +31,6 @@ from homeassistant.const import (
ATTR_SUPPORTED_FEATURES, ATTR_SUPPORTED_FEATURES,
ATTR_UNIT_OF_MEASUREMENT, ATTR_UNIT_OF_MEASUREMENT,
DEVICE_DEFAULT_NAME, DEVICE_DEFAULT_NAME,
MAX_LENGTH_STATE_STATE,
STATE_OFF, STATE_OFF,
STATE_ON, STATE_ON,
STATE_UNAVAILABLE, STATE_UNAVAILABLE,
@ -1217,16 +1216,6 @@ class Entity(
self._context = None self._context = None
self._context_set = None self._context_set = None
if len(state) > MAX_LENGTH_STATE_STATE:
_LOGGER.error(
"State %s for %s is longer than %s, falling back to %s",
state,
self.entity_id,
MAX_LENGTH_STATE_STATE,
STATE_UNKNOWN,
)
state = STATE_UNKNOWN
# Intentionally called with positional args for performance reasons # Intentionally called with positional args for performance reasons
self.hass.states.async_set_internal( self.hass.states.async_set_internal(
self.entity_id, self.entity_id,

View File

@ -1711,7 +1711,7 @@ async def test_invalid_state(
ent.async_write_ha_state() ent.async_write_ha_state()
assert hass.states.get("test.test").state == STATE_UNKNOWN assert hass.states.get("test.test").state == STATE_UNKNOWN
assert ( assert (
"homeassistant.helpers.entity", "homeassistant.core",
logging.ERROR, logging.ERROR,
f"State {long_state} for test.test is longer than 255, " f"State {long_state} for test.test is longer than 255, "
f"falling back to {STATE_UNKNOWN}", f"falling back to {STATE_UNKNOWN}",

View File

@ -35,6 +35,7 @@ from homeassistant.const import (
EVENT_STATE_CHANGED, EVENT_STATE_CHANGED,
EVENT_STATE_REPORTED, EVENT_STATE_REPORTED,
MATCH_ALL, MATCH_ALL,
STATE_UNKNOWN,
) )
from homeassistant.core import ( from homeassistant.core import (
CoreState, CoreState,
@ -1368,9 +1369,6 @@ def test_state_init() -> None:
with pytest.raises(InvalidEntityFormatError): with pytest.raises(InvalidEntityFormatError):
ha.State("invalid_entity_format", "test_state") ha.State("invalid_entity_format", "test_state")
with pytest.raises(InvalidStateError):
ha.State("domain.long_state", "t" * 256)
def test_state_domain() -> None: def test_state_domain() -> None:
"""Test domain.""" """Test domain."""
@ -1440,6 +1438,38 @@ def test_state_repr() -> None:
) )
async def test_statemachine_async_set_invalid_state(hass: HomeAssistant) -> None:
"""Test setting an invalid state with the async_set method."""
with pytest.raises(
InvalidStateError,
match="Invalid state with length 256. State max length is 255 characters.",
):
hass.states.async_set("light.bowl", "o" * 256, {})
async def test_statemachine_async_set_internal_invalid_state(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test setting an invalid state with the async_set_internal method."""
long_state = "o" * 256
hass.states.async_set_internal(
"light.bowl",
long_state,
{},
force_update=False,
context=None,
state_info=None,
timestamp=time.time(),
)
assert hass.states.get("light.bowl").state == STATE_UNKNOWN
assert (
"homeassistant.core",
logging.ERROR,
f"State {long_state} for light.bowl is longer than 255, "
f"falling back to {STATE_UNKNOWN}",
) in caplog.record_tuples
async def test_statemachine_is_state(hass: HomeAssistant) -> None: async def test_statemachine_is_state(hass: HomeAssistant) -> None:
"""Test is_state method.""" """Test is_state method."""
hass.states.async_set("light.bowl", "on", {}) hass.states.async_set("light.bowl", "on", {})