Don't allow partial update of counter settings (#78371)

This commit is contained in:
Erik Montnemery 2022-09-13 20:55:06 +02:00 committed by GitHub
parent 15f104911a
commit 47da1c456b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 35 deletions

View File

@ -47,7 +47,7 @@ SERVICE_CONFIGURE = "configure"
STORAGE_KEY = DOMAIN STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 1
CREATE_FIELDS = { STORAGE_FIELDS = {
vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_ICON): cv.icon,
vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.positive_int, vol.Optional(CONF_INITIAL, default=DEFAULT_INITIAL): cv.positive_int,
vol.Required(CONF_NAME): vol.All(cv.string, vol.Length(min=1)), vol.Required(CONF_NAME): vol.All(cv.string, vol.Length(min=1)),
@ -57,16 +57,6 @@ CREATE_FIELDS = {
vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int, vol.Optional(CONF_STEP, default=DEFAULT_STEP): cv.positive_int,
} }
UPDATE_FIELDS = {
vol.Optional(CONF_ICON): cv.icon,
vol.Optional(CONF_INITIAL): cv.positive_int,
vol.Optional(CONF_NAME): cv.string,
vol.Optional(CONF_MAXIMUM): vol.Any(None, vol.Coerce(int)),
vol.Optional(CONF_MINIMUM): vol.Any(None, vol.Coerce(int)),
vol.Optional(CONF_RESTORE): cv.boolean,
vol.Optional(CONF_STEP): cv.positive_int,
}
def _none_to_empty_dict(value): def _none_to_empty_dict(value):
if value is None: if value is None:
@ -128,7 +118,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await storage_collection.async_load() await storage_collection.async_load()
collection.StorageCollectionWebsocket( collection.StorageCollectionWebsocket(
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS
).async_setup(hass) ).async_setup(hass)
component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment") component.async_register_entity_service(SERVICE_INCREMENT, {}, "async_increment")
@ -152,12 +142,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
class CounterStorageCollection(collection.StorageCollection): class CounterStorageCollection(collection.StorageCollection):
"""Input storage based collection.""" """Input storage based collection."""
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS) CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS)
async def _process_create_data(self, data: dict) -> dict: async def _process_create_data(self, data: dict) -> dict:
"""Validate the config is valid.""" """Validate the config is valid."""
return self.CREATE_SCHEMA(data) return self.CREATE_UPDATE_SCHEMA(data)
@callback @callback
def _get_suggested_id(self, info: dict) -> str: def _get_suggested_id(self, info: dict) -> str:
@ -166,8 +155,8 @@ class CounterStorageCollection(collection.StorageCollection):
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(self, data: dict, update_data: dict) -> dict:
"""Return a new updated data object.""" """Return a new updated data object."""
update_data = self.UPDATE_SCHEMA(update_data) update_data = self.CREATE_UPDATE_SCHEMA(update_data)
return {**data, **update_data} return {CONF_ID: data[CONF_ID]} | update_data
class Counter(collection.CollectionEntity, RestoreEntity): class Counter(collection.CollectionEntity, RestoreEntity):

View File

@ -591,17 +591,15 @@ async def test_ws_delete(hass, hass_ws_client, storage_setup):
async def test_update_min_max(hass, hass_ws_client, storage_setup): async def test_update_min_max(hass, hass_ws_client, storage_setup):
"""Test updating min/max updates the state.""" """Test updating min/max updates the state."""
items = [ settings = {
{ "initial": 15,
"id": "from_storage", "name": "from storage",
"initial": 15, "maximum": 100,
"name": "from storage", "minimum": 10,
"maximum": 100, "step": 3,
"minimum": 10, "restore": True,
"step": 3, }
"restore": True, items = [{"id": "from_storage"} | settings]
}
]
assert await storage_setup(items) assert await storage_setup(items)
input_id = "from_storage" input_id = "from_storage"
@ -618,16 +616,18 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup):
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
updated_settings = settings | {"minimum": 19}
await client.send_json( await client.send_json(
{ {
"id": 6, "id": 6,
"type": f"{DOMAIN}/update", "type": f"{DOMAIN}/update",
f"{DOMAIN}_id": f"{input_id}", f"{DOMAIN}_id": f"{input_id}",
"minimum": 19, **updated_settings,
} }
) )
resp = await client.receive_json() resp = await client.receive_json()
assert resp["success"] assert resp["success"]
assert resp["result"] == {"id": "from_storage"} | updated_settings
state = hass.states.get(input_entity_id) state = hass.states.get(input_entity_id)
assert int(state.state) == 19 assert int(state.state) == 19
@ -635,18 +635,18 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup):
assert state.attributes[ATTR_MAXIMUM] == 100 assert state.attributes[ATTR_MAXIMUM] == 100
assert state.attributes[ATTR_STEP] == 3 assert state.attributes[ATTR_STEP] == 3
updated_settings = settings | {"maximum": 5, "minimum": 2, "step": 5}
await client.send_json( await client.send_json(
{ {
"id": 7, "id": 7,
"type": f"{DOMAIN}/update", "type": f"{DOMAIN}/update",
f"{DOMAIN}_id": f"{input_id}", f"{DOMAIN}_id": f"{input_id}",
"maximum": 5, **updated_settings,
"minimum": 2,
"step": 5,
} }
) )
resp = await client.receive_json() resp = await client.receive_json()
assert resp["success"] assert resp["success"]
assert resp["result"] == {"id": "from_storage"} | updated_settings
state = hass.states.get(input_entity_id) state = hass.states.get(input_entity_id)
assert int(state.state) == 5 assert int(state.state) == 5
@ -654,18 +654,18 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup):
assert state.attributes[ATTR_MAXIMUM] == 5 assert state.attributes[ATTR_MAXIMUM] == 5
assert state.attributes[ATTR_STEP] == 5 assert state.attributes[ATTR_STEP] == 5
updated_settings = settings | {"maximum": None, "minimum": None, "step": 6}
await client.send_json( await client.send_json(
{ {
"id": 8, "id": 8,
"type": f"{DOMAIN}/update", "type": f"{DOMAIN}/update",
f"{DOMAIN}_id": f"{input_id}", f"{DOMAIN}_id": f"{input_id}",
"maximum": None, **updated_settings,
"minimum": None,
"step": 6,
} }
) )
resp = await client.receive_json() resp = await client.receive_json()
assert resp["success"] assert resp["success"]
assert resp["result"] == {"id": "from_storage"} | updated_settings
state = hass.states.get(input_entity_id) state = hass.states.get(input_entity_id)
assert int(state.state) == 5 assert int(state.state) == 5