diff --git a/homeassistant/components/input_number/__init__.py b/homeassistant/components/input_number/__init__.py index affad6ca30f..3a7f7b29f13 100644 --- a/homeassistant/components/input_number/__init__.py +++ b/homeassistant/components/input_number/__init__.py @@ -65,7 +65,7 @@ def _cv_input_number(cfg): return cfg -CREATE_FIELDS = { +STORAGE_FIELDS = { vol.Required(CONF_NAME): vol.All(str, vol.Length(min=1)), vol.Required(CONF_MIN): vol.Coerce(float), vol.Required(CONF_MAX): vol.Coerce(float), @@ -76,17 +76,6 @@ CREATE_FIELDS = { vol.Optional(CONF_MODE, default=MODE_SLIDER): vol.In([MODE_BOX, MODE_SLIDER]), } -UPDATE_FIELDS = { - vol.Optional(CONF_NAME): cv.string, - vol.Optional(CONF_MIN): vol.Coerce(float), - vol.Optional(CONF_MAX): vol.Coerce(float), - vol.Optional(CONF_INITIAL): vol.Coerce(float), - vol.Optional(CONF_STEP): vol.All(vol.Coerce(float), vol.Range(min=1e-9)), - vol.Optional(CONF_ICON): cv.icon, - vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, - vol.Optional(CONF_MODE): vol.In([MODE_BOX, MODE_SLIDER]), -} - CONFIG_SCHEMA = vol.Schema( { DOMAIN: cv.schema_with_slug_keys( @@ -148,7 +137,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await storage_collection.async_load() collection.StorageCollectionWebsocket( - storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS + storage_collection, DOMAIN, DOMAIN, STORAGE_FIELDS, STORAGE_FIELDS ).async_setup(hass) async def reload_service_handler(service_call: ServiceCall) -> None: @@ -184,12 +173,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: class NumberStorageCollection(collection.StorageCollection): """Input storage based collection.""" - CREATE_SCHEMA = vol.Schema(vol.All(CREATE_FIELDS, _cv_input_number)) - UPDATE_SCHEMA = vol.Schema(UPDATE_FIELDS) + SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_number)) async def _process_create_data(self, data: dict) -> dict: """Validate the config is valid.""" - return self.CREATE_SCHEMA(data) + return self.SCHEMA(data) @callback def _get_suggested_id(self, info: dict) -> str: @@ -214,8 +202,8 @@ class NumberStorageCollection(collection.StorageCollection): async def _update_data(self, data: dict, update_data: dict) -> dict: """Return a new updated data object.""" - update_data = self.UPDATE_SCHEMA(update_data) - return _cv_input_number({**data, **update_data}) + update_data = self.SCHEMA(update_data) + return {CONF_ID: data[CONF_ID]} | update_data class InputNumber(collection.CollectionEntity, RestoreEntity): diff --git a/tests/components/input_number/test_init.py b/tests/components/input_number/test_init.py index bec05d3f344..7ba7489f644 100644 --- a/tests/components/input_number/test_init.py +++ b/tests/components/input_number/test_init.py @@ -507,16 +507,14 @@ async def test_ws_delete(hass, hass_ws_client, storage_setup): async def test_update_min_max(hass, hass_ws_client, storage_setup): """Test updating min/max updates the state.""" - items = [ - { - "id": "from_storage", - "name": "from storage", - "max": 100, - "min": 0, - "step": 1, - "mode": "slider", - } - ] + settings = { + "name": "from storage", + "max": 100, + "min": 0, + "step": 1, + "mode": "slider", + } + items = [{"id": "from_storage"} | settings] assert await storage_setup(items) input_id = "from_storage" @@ -530,26 +528,34 @@ async def test_update_min_max(hass, hass_ws_client, storage_setup): client = await hass_ws_client(hass) + updated_settings = settings | {"min": 9} await client.send_json( - {"id": 6, "type": f"{DOMAIN}/update", f"{DOMAIN}_id": f"{input_id}", "min": 9} + { + "id": 6, + "type": f"{DOMAIN}/update", + f"{DOMAIN}_id": f"{input_id}", + **updated_settings, + } ) resp = await client.receive_json() assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings state = hass.states.get(input_entity_id) assert float(state.state) == 9 + updated_settings = settings | {"max": 5} await client.send_json( { "id": 7, "type": f"{DOMAIN}/update", f"{DOMAIN}_id": f"{input_id}", - "max": 5, - "min": 0, + **updated_settings, } ) resp = await client.receive_json() assert resp["success"] + assert resp["result"] == {"id": "from_storage"} | updated_settings state = hass.states.get(input_entity_id) assert float(state.state) == 5