mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Protect state.as_dict from mutation (#65693)
This commit is contained in:
parent
0d3bbfc9a7
commit
5da923c341
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import Any, TypeVar, cast, overload
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
@ -11,6 +11,16 @@ from .const import REDACTED
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def async_redact_data(data: Mapping, to_redact: Iterable[Any]) -> dict: # type: ignore
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
||||
...
|
||||
|
||||
|
||||
@callback
|
||||
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
||||
"""Redact sensitive data in a dict."""
|
||||
@ -25,7 +35,7 @@ def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
|
||||
for key, value in redacted.items():
|
||||
if key in to_redact:
|
||||
redacted[key] = REDACTED
|
||||
elif isinstance(value, dict):
|
||||
elif isinstance(value, Mapping):
|
||||
redacted[key] = async_redact_data(value, to_redact)
|
||||
elif isinstance(value, list):
|
||||
redacted[key] = [async_redact_data(item, to_redact) for item in value]
|
||||
|
@ -457,7 +457,7 @@ async def _register_service(
|
||||
}
|
||||
|
||||
async def execute_service(call: ServiceCall) -> None:
|
||||
await entry_data.client.execute_service(service, call.data) # type: ignore[arg-type]
|
||||
await entry_data.client.execute_service(service, call.data)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN, service_name, execute_service, vol.Schema(schema)
|
||||
|
@ -2,9 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.const import (
|
||||
@ -112,8 +111,6 @@ async def async_reproduce_states(
|
||||
)
|
||||
|
||||
|
||||
def check_attr_equal(
|
||||
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
|
||||
) -> bool:
|
||||
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
|
||||
"""Return true if the given attributes are equal."""
|
||||
return attr1.get(attr_str) == attr2.get(attr_str)
|
||||
|
@ -2,9 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ATTR_OPTION
|
||||
@ -80,8 +79,6 @@ async def async_reproduce_states(
|
||||
)
|
||||
|
||||
|
||||
def check_attr_equal(
|
||||
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
|
||||
) -> bool:
|
||||
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
|
||||
"""Return true if the given attributes are equal."""
|
||||
return attr1.get(attr_str) == attr2.get(attr_str)
|
||||
|
@ -546,7 +546,7 @@ class KNXModule:
|
||||
replaced_exposure.device.name,
|
||||
)
|
||||
replaced_exposure.shutdown()
|
||||
exposure = create_knx_exposure(self.hass, self.xknx, call.data) # type: ignore[arg-type]
|
||||
exposure = create_knx_exposure(self.hass, self.xknx, call.data)
|
||||
self.service_exposures[group_address] = exposure
|
||||
_LOGGER.debug(
|
||||
"Service exposure_register registered exposure for '%s' - %s",
|
||||
|
@ -2,9 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
from homeassistant.const import (
|
||||
@ -213,8 +212,6 @@ async def async_reproduce_states(
|
||||
)
|
||||
|
||||
|
||||
def check_attr_equal(
|
||||
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
|
||||
) -> bool:
|
||||
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
|
||||
"""Return true if the given attributes are equal."""
|
||||
return attr1.get(attr_str) == attr2.get(attr_str)
|
||||
|
@ -1,9 +1,9 @@
|
||||
"""Support for Renault services."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import voluptuous as vol
|
||||
@ -126,7 +126,7 @@ def setup_services(hass: HomeAssistant) -> None:
|
||||
result = await proxy.vehicle.set_charge_start()
|
||||
LOGGER.debug("Charge start result: %s", result)
|
||||
|
||||
def get_vehicle_proxy(service_call_data: MappingProxyType) -> RenaultVehicleProxy:
|
||||
def get_vehicle_proxy(service_call_data: Mapping) -> RenaultVehicleProxy:
|
||||
"""Get vehicle from service_call data."""
|
||||
device_registry = dr.async_get(hass)
|
||||
device_id = service_call_data[ATTR_VEHICLE]
|
||||
|
@ -2,8 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Final, cast
|
||||
|
||||
from aioshelly.block_device import Block
|
||||
@ -140,7 +140,7 @@ class BlockSleepingClimate(
|
||||
self.control_result: dict[str, Any] | None = None
|
||||
self.device_block: Block | None = device_block
|
||||
self.last_state: State | None = None
|
||||
self.last_state_attributes: MappingProxyType[str, Any]
|
||||
self.last_state_attributes: Mapping[str, Any]
|
||||
self._preset_modes: list[str] = []
|
||||
|
||||
if self.block is not None and self.device_block is not None:
|
||||
|
@ -24,7 +24,6 @@ import pathlib
|
||||
import re
|
||||
import threading
|
||||
from time import monotonic
|
||||
from types import MappingProxyType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -83,6 +82,7 @@ from .util.async_ import (
|
||||
run_callback_threadsafe,
|
||||
shutdown_run_callback_threadsafe,
|
||||
)
|
||||
from .util.read_only_dict import ReadOnlyDict
|
||||
from .util.timeout import TimeoutManager
|
||||
from .util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitSystem
|
||||
|
||||
@ -1049,12 +1049,12 @@ class State:
|
||||
|
||||
self.entity_id = entity_id.lower()
|
||||
self.state = state
|
||||
self.attributes = MappingProxyType(attributes or {})
|
||||
self.attributes = ReadOnlyDict(attributes or {})
|
||||
self.last_updated = last_updated or dt_util.utcnow()
|
||||
self.last_changed = last_changed or self.last_updated
|
||||
self.context = context or Context()
|
||||
self.domain, self.object_id = split_entity_id(self.entity_id)
|
||||
self._as_dict: dict[str, Collection[Any]] | None = None
|
||||
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -1063,7 +1063,7 @@ class State:
|
||||
"_", " "
|
||||
)
|
||||
|
||||
def as_dict(self) -> dict[str, Collection[Any]]:
|
||||
def as_dict(self) -> ReadOnlyDict[str, Collection[Any]]:
|
||||
"""Return a dict representation of the State.
|
||||
|
||||
Async friendly.
|
||||
@ -1077,14 +1077,16 @@ class State:
|
||||
last_updated_isoformat = last_changed_isoformat
|
||||
else:
|
||||
last_updated_isoformat = self.last_updated.isoformat()
|
||||
self._as_dict = {
|
||||
"entity_id": self.entity_id,
|
||||
"state": self.state,
|
||||
"attributes": dict(self.attributes),
|
||||
"last_changed": last_changed_isoformat,
|
||||
"last_updated": last_updated_isoformat,
|
||||
"context": self.context.as_dict(),
|
||||
}
|
||||
self._as_dict = ReadOnlyDict(
|
||||
{
|
||||
"entity_id": self.entity_id,
|
||||
"state": self.state,
|
||||
"attributes": self.attributes,
|
||||
"last_changed": last_changed_isoformat,
|
||||
"last_updated": last_updated_isoformat,
|
||||
"context": ReadOnlyDict(self.context.as_dict()),
|
||||
}
|
||||
)
|
||||
return self._as_dict
|
||||
|
||||
@classmethod
|
||||
@ -1343,7 +1345,7 @@ class StateMachine:
|
||||
last_changed = None
|
||||
else:
|
||||
same_state = old_state.state == new_state and not force_update
|
||||
same_attr = old_state.attributes == MappingProxyType(attributes)
|
||||
same_attr = old_state.attributes == attributes
|
||||
last_changed = old_state.last_changed if same_state else None
|
||||
|
||||
if same_state and same_attr:
|
||||
@ -1404,7 +1406,7 @@ class ServiceCall:
|
||||
"""Initialize a service call."""
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = MappingProxyType(data or {})
|
||||
self.data = ReadOnlyDict(data or {})
|
||||
self.context = context or Context()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -2,14 +2,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable, KeysView
|
||||
from collections.abc import Callable, Coroutine, Iterable, KeysView, Mapping
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import threading
|
||||
from types import MappingProxyType
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import slugify as unicode_slug
|
||||
@ -53,7 +52,7 @@ def slugify(text: str | None, *, separator: str = "_") -> str:
|
||||
|
||||
def repr_helper(inp: Any) -> str:
|
||||
"""Help creating a more readable string representation of objects."""
|
||||
if isinstance(inp, (dict, MappingProxyType)):
|
||||
if isinstance(inp, Mapping):
|
||||
return ", ".join(
|
||||
f"{repr_helper(key)}={repr_helper(item)}" for key, item in inp.items()
|
||||
)
|
||||
|
23
homeassistant/util/read_only_dict.py
Normal file
23
homeassistant/util/read_only_dict.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Read only dictionary."""
|
||||
from typing import Any, TypeVar
|
||||
|
||||
|
||||
def _readonly(*args: Any, **kwargs: Any) -> Any:
|
||||
"""Raise an exception when a read only dict is modified."""
|
||||
raise RuntimeError("Cannot modify ReadOnlyDict")
|
||||
|
||||
|
||||
Key = TypeVar("Key")
|
||||
Value = TypeVar("Value")
|
||||
|
||||
|
||||
class ReadOnlyDict(dict[Key, Value]):
|
||||
"""Read only version of dict that is compatible with dict types."""
|
||||
|
||||
__setitem__ = _readonly
|
||||
__delitem__ = _readonly
|
||||
pop = _readonly
|
||||
popitem = _readonly
|
||||
clear = _readonly
|
||||
update = _readonly
|
||||
setdefault = _readonly
|
@ -931,9 +931,12 @@ def mock_restore_cache(hass, states):
|
||||
last_states = {}
|
||||
for state in states:
|
||||
restored_state = state.as_dict()
|
||||
restored_state["attributes"] = json.loads(
|
||||
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
||||
)
|
||||
restored_state = {
|
||||
**restored_state,
|
||||
"attributes": json.loads(
|
||||
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
||||
),
|
||||
}
|
||||
last_states[state.entity_id] = restore_state.StoredState(
|
||||
State.from_dict(restored_state), now
|
||||
)
|
||||
|
@ -39,6 +39,7 @@ from homeassistant.exceptions import (
|
||||
ServiceNotFound,
|
||||
)
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.read_only_dict import ReadOnlyDict
|
||||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||
|
||||
from tests.common import async_capture_events, async_mock_service
|
||||
@ -377,10 +378,14 @@ def test_state_as_dict():
|
||||
"last_updated": last_time.isoformat(),
|
||||
"state": "on",
|
||||
}
|
||||
assert state.as_dict() == expected
|
||||
as_dict_1 = state.as_dict()
|
||||
assert isinstance(as_dict_1, ReadOnlyDict)
|
||||
assert isinstance(as_dict_1["attributes"], ReadOnlyDict)
|
||||
assert isinstance(as_dict_1["context"], ReadOnlyDict)
|
||||
assert as_dict_1 == expected
|
||||
# 2nd time to verify cache
|
||||
assert state.as_dict() == expected
|
||||
assert state.as_dict() is state.as_dict()
|
||||
assert state.as_dict() is as_dict_1
|
||||
|
||||
|
||||
async def test_eventbus_add_remove_listener(hass):
|
||||
|
36
tests/util/test_read_only_dict.py
Normal file
36
tests/util/test_read_only_dict.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Test read only dictionary."""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.util.read_only_dict import ReadOnlyDict
|
||||
|
||||
|
||||
def test_read_only_dict():
|
||||
"""Test read only dictionary."""
|
||||
data = ReadOnlyDict({"hello": "world"})
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data["hello"] = "universe"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data["other_key"] = "universe"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data.pop("hello")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data.popitem()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data.clear()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data.update({"yo": "yo"})
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
data.setdefault("yo", "yo")
|
||||
|
||||
assert isinstance(data, dict)
|
||||
assert dict(data) == {"hello": "world"}
|
||||
assert json.dumps(data) == json.dumps({"hello": "world"})
|
Loading…
x
Reference in New Issue
Block a user