Protect state.as_dict from mutation (#65693)

This commit is contained in:
Paulus Schoutsen 2022-02-04 14:45:25 -08:00 committed by GitHub
parent 0d3bbfc9a7
commit 5da923c341
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 114 additions and 45 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Iterable, Mapping
from typing import Any, TypeVar, cast
from typing import Any, TypeVar, cast, overload
from homeassistant.core import callback
@ -11,6 +11,16 @@ from .const import REDACTED
T = TypeVar("T")
@overload
def async_redact_data(data: Mapping, to_redact: Iterable[Any]) -> dict: # type: ignore
...
@overload
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
...
@callback
def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
"""Redact sensitive data in a dict."""
@ -25,7 +35,7 @@ def async_redact_data(data: T, to_redact: Iterable[Any]) -> T:
for key, value in redacted.items():
if key in to_redact:
redacted[key] = REDACTED
elif isinstance(value, dict):
elif isinstance(value, Mapping):
redacted[key] = async_redact_data(value, to_redact)
elif isinstance(value, list):
redacted[key] = [async_redact_data(item, to_redact) for item in value]

View File

@ -457,7 +457,7 @@ async def _register_service(
}
async def execute_service(call: ServiceCall) -> None:
await entry_data.client.execute_service(service, call.data) # type: ignore[arg-type]
await entry_data.client.execute_service(service, call.data)
hass.services.async_register(
DOMAIN, service_name, execute_service, vol.Schema(schema)

View File

@ -2,9 +2,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
import logging
from types import MappingProxyType
from typing import Any
from homeassistant.const import (
@ -112,8 +111,6 @@ async def async_reproduce_states(
)
def check_attr_equal(
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
) -> bool:
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
"""Return true if the given attributes are equal."""
return attr1.get(attr_str) == attr2.get(attr_str)

View File

@ -2,9 +2,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
import logging
from types import MappingProxyType
from typing import Any
from homeassistant.const import ATTR_ENTITY_ID, ATTR_OPTION
@ -80,8 +79,6 @@ async def async_reproduce_states(
)
def check_attr_equal(
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
) -> bool:
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
"""Return true if the given attributes are equal."""
return attr1.get(attr_str) == attr2.get(attr_str)

View File

@ -546,7 +546,7 @@ class KNXModule:
replaced_exposure.device.name,
)
replaced_exposure.shutdown()
exposure = create_knx_exposure(self.hass, self.xknx, call.data) # type: ignore[arg-type]
exposure = create_knx_exposure(self.hass, self.xknx, call.data)
self.service_exposures[group_address] = exposure
_LOGGER.debug(
"Service exposure_register registered exposure for '%s' - %s",

View File

@ -2,9 +2,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
import logging
from types import MappingProxyType
from typing import Any, NamedTuple, cast
from homeassistant.const import (
@ -213,8 +212,6 @@ async def async_reproduce_states(
)
def check_attr_equal(
attr1: MappingProxyType, attr2: MappingProxyType, attr_str: str
) -> bool:
def check_attr_equal(attr1: Mapping, attr2: Mapping, attr_str: str) -> bool:
"""Return true if the given attributes are equal."""
return attr1.get(attr_str) == attr2.get(attr_str)

View File

@ -1,9 +1,9 @@
"""Support for Renault services."""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
import logging
from types import MappingProxyType
from typing import TYPE_CHECKING, Any
import voluptuous as vol
@ -126,7 +126,7 @@ def setup_services(hass: HomeAssistant) -> None:
result = await proxy.vehicle.set_charge_start()
LOGGER.debug("Charge start result: %s", result)
def get_vehicle_proxy(service_call_data: MappingProxyType) -> RenaultVehicleProxy:
def get_vehicle_proxy(service_call_data: Mapping) -> RenaultVehicleProxy:
"""Get vehicle from service_call data."""
device_registry = dr.async_get(hass)
device_id = service_call_data[ATTR_VEHICLE]

View File

@ -2,8 +2,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Mapping
import logging
from types import MappingProxyType
from typing import Any, Final, cast
from aioshelly.block_device import Block
@ -140,7 +140,7 @@ class BlockSleepingClimate(
self.control_result: dict[str, Any] | None = None
self.device_block: Block | None = device_block
self.last_state: State | None = None
self.last_state_attributes: MappingProxyType[str, Any]
self.last_state_attributes: Mapping[str, Any]
self._preset_modes: list[str] = []
if self.block is not None and self.device_block is not None:

View File

@ -24,7 +24,6 @@ import pathlib
import re
import threading
from time import monotonic
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
@ -83,6 +82,7 @@ from .util.async_ import (
run_callback_threadsafe,
shutdown_run_callback_threadsafe,
)
from .util.read_only_dict import ReadOnlyDict
from .util.timeout import TimeoutManager
from .util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitSystem
@ -1049,12 +1049,12 @@ class State:
self.entity_id = entity_id.lower()
self.state = state
self.attributes = MappingProxyType(attributes or {})
self.attributes = ReadOnlyDict(attributes or {})
self.last_updated = last_updated or dt_util.utcnow()
self.last_changed = last_changed or self.last_updated
self.context = context or Context()
self.domain, self.object_id = split_entity_id(self.entity_id)
self._as_dict: dict[str, Collection[Any]] | None = None
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
@property
def name(self) -> str:
@ -1063,7 +1063,7 @@ class State:
"_", " "
)
def as_dict(self) -> dict[str, Collection[Any]]:
def as_dict(self) -> ReadOnlyDict[str, Collection[Any]]:
"""Return a dict representation of the State.
Async friendly.
@ -1077,14 +1077,16 @@ class State:
last_updated_isoformat = last_changed_isoformat
else:
last_updated_isoformat = self.last_updated.isoformat()
self._as_dict = {
"entity_id": self.entity_id,
"state": self.state,
"attributes": dict(self.attributes),
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
"context": self.context.as_dict(),
}
self._as_dict = ReadOnlyDict(
{
"entity_id": self.entity_id,
"state": self.state,
"attributes": self.attributes,
"last_changed": last_changed_isoformat,
"last_updated": last_updated_isoformat,
"context": ReadOnlyDict(self.context.as_dict()),
}
)
return self._as_dict
@classmethod
@ -1343,7 +1345,7 @@ class StateMachine:
last_changed = None
else:
same_state = old_state.state == new_state and not force_update
same_attr = old_state.attributes == MappingProxyType(attributes)
same_attr = old_state.attributes == attributes
last_changed = old_state.last_changed if same_state else None
if same_state and same_attr:
@ -1404,7 +1406,7 @@ class ServiceCall:
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = MappingProxyType(data or {})
self.data = ReadOnlyDict(data or {})
self.context = context or Context()
def __repr__(self) -> str:

View File

@ -2,14 +2,13 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable, KeysView
from collections.abc import Callable, Coroutine, Iterable, KeysView, Mapping
from datetime import datetime, timedelta
from functools import wraps
import random
import re
import string
import threading
from types import MappingProxyType
from typing import Any, TypeVar
import slugify as unicode_slug
@ -53,7 +52,7 @@ def slugify(text: str | None, *, separator: str = "_") -> str:
def repr_helper(inp: Any) -> str:
"""Help creating a more readable string representation of objects."""
if isinstance(inp, (dict, MappingProxyType)):
if isinstance(inp, Mapping):
return ", ".join(
f"{repr_helper(key)}={repr_helper(item)}" for key, item in inp.items()
)

View File

@ -0,0 +1,23 @@
"""Read only dictionary."""
from typing import Any, TypeVar
def _readonly(*args: Any, **kwargs: Any) -> Any:
"""Raise an exception when a read only dict is modified."""
raise RuntimeError("Cannot modify ReadOnlyDict")
Key = TypeVar("Key")
Value = TypeVar("Value")
class ReadOnlyDict(dict[Key, Value]):
"""Read only version of dict that is compatible with dict types."""
__setitem__ = _readonly
__delitem__ = _readonly
pop = _readonly
popitem = _readonly
clear = _readonly
update = _readonly
setdefault = _readonly

View File

@ -931,9 +931,12 @@ def mock_restore_cache(hass, states):
last_states = {}
for state in states:
restored_state = state.as_dict()
restored_state["attributes"] = json.loads(
json.dumps(restored_state["attributes"], cls=JSONEncoder)
)
restored_state = {
**restored_state,
"attributes": json.loads(
json.dumps(restored_state["attributes"], cls=JSONEncoder)
),
}
last_states[state.entity_id] = restore_state.StoredState(
State.from_dict(restored_state), now
)

View File

@ -39,6 +39,7 @@ from homeassistant.exceptions import (
ServiceNotFound,
)
import homeassistant.util.dt as dt_util
from homeassistant.util.read_only_dict import ReadOnlyDict
from homeassistant.util.unit_system import METRIC_SYSTEM
from tests.common import async_capture_events, async_mock_service
@ -377,10 +378,14 @@ def test_state_as_dict():
"last_updated": last_time.isoformat(),
"state": "on",
}
assert state.as_dict() == expected
as_dict_1 = state.as_dict()
assert isinstance(as_dict_1, ReadOnlyDict)
assert isinstance(as_dict_1["attributes"], ReadOnlyDict)
assert isinstance(as_dict_1["context"], ReadOnlyDict)
assert as_dict_1 == expected
# 2nd time to verify cache
assert state.as_dict() == expected
assert state.as_dict() is state.as_dict()
assert state.as_dict() is as_dict_1
async def test_eventbus_add_remove_listener(hass):

View File

@ -0,0 +1,36 @@
"""Test read only dictionary."""
import json
import pytest
from homeassistant.util.read_only_dict import ReadOnlyDict
def test_read_only_dict():
"""Test read only dictionary."""
data = ReadOnlyDict({"hello": "world"})
with pytest.raises(RuntimeError):
data["hello"] = "universe"
with pytest.raises(RuntimeError):
data["other_key"] = "universe"
with pytest.raises(RuntimeError):
data.pop("hello")
with pytest.raises(RuntimeError):
data.popitem()
with pytest.raises(RuntimeError):
data.clear()
with pytest.raises(RuntimeError):
data.update({"yo": "yo"})
with pytest.raises(RuntimeError):
data.setdefault("yo", "yo")
assert isinstance(data, dict)
assert dict(data) == {"hello": "world"}
assert json.dumps(data) == json.dumps({"hello": "world"})