From 224cc779c433414170425e6aed1a0e5b97c0838c Mon Sep 17 00:00:00 2001 From: jan iversen Date: Sun, 16 May 2021 08:40:19 +0200 Subject: [PATCH] Correct Modbus platform cover restore state (#50421) * Correct cover restore state. * Change mock usage. * Add states to convert. --- homeassistant/components/modbus/cover.py | 16 ++++++- tests/components/modbus/test_modbus_cover.py | 44 ++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/homeassistant/components/modbus/cover.py b/homeassistant/components/modbus/cover.py index edb81ae7eb3..48dc08a18b9 100644 --- a/homeassistant/components/modbus/cover.py +++ b/homeassistant/components/modbus/cover.py @@ -12,6 +12,12 @@ from homeassistant.const import ( CONF_NAME, CONF_SCAN_INTERVAL, CONF_SLAVE, + STATE_CLOSED, + STATE_CLOSING, + STATE_OPEN, + STATE_OPENING, + STATE_UNAVAILABLE, + STATE_UNKNOWN, ) from homeassistant.core import HomeAssistant from homeassistant.helpers.event import async_track_time_interval @@ -105,7 +111,15 @@ class ModbusCover(CoverEntity, RestoreEntity): """Handle entity which will be added.""" state = await self.async_get_last_state() if state: - self._value = state.state + convert = { + STATE_CLOSED: self._state_closed, + STATE_CLOSING: self._state_closing, + STATE_OPENING: self._state_opening, + STATE_OPEN: self._state_open, + STATE_UNAVAILABLE: None, + STATE_UNKNOWN: None, + } + self._value = convert[state.state] async_track_time_interval(self.hass, self.async_update, self._scan_interval) diff --git a/tests/components/modbus/test_modbus_cover.py b/tests/components/modbus/test_modbus_cover.py index 09d23ebf8bd..f30ee79bd52 100644 --- a/tests/components/modbus/test_modbus_cover.py +++ b/tests/components/modbus/test_modbus_cover.py @@ -8,6 +8,11 @@ from homeassistant.components.modbus.const import ( CALL_TYPE_COIL, CALL_TYPE_REGISTER_HOLDING, CONF_REGISTER, + CONF_STATE_CLOSED, + CONF_STATE_CLOSING, + CONF_STATE_OPEN, + CONF_STATE_OPENING, + CONF_STATUS_REGISTER, CONF_STATUS_REGISTER_TYPE, ) from homeassistant.const import ( @@ -16,11 +21,16 @@ from homeassistant.const import ( CONF_SCAN_INTERVAL, CONF_SLAVE, STATE_CLOSED, + STATE_CLOSING, STATE_OPEN, + STATE_OPENING, ) +from homeassistant.core import State from .conftest import ReadResult, base_config_test, base_test, prepare_service_update +from tests.common import mock_restore_cache + @pytest.mark.parametrize( "do_options", @@ -202,3 +212,37 @@ async def test_service_cover_update(hass, mock_pymodbus): "homeassistant", "update_entity", {"entity_id": entity_id}, blocking=True ) assert hass.states.get(entity_id).state == STATE_OPEN + + +@pytest.mark.parametrize( + "state", [STATE_CLOSED, STATE_CLOSING, STATE_OPENING, STATE_OPEN] +) +async def test_restore_state_cover(hass, state): + """Run test for cover restore state.""" + + entity_id = "cover.test" + cover_name = "test" + config = { + CONF_NAME: cover_name, + CALL_TYPE_COIL: 1234, + CONF_STATE_OPEN: 1, + CONF_STATE_CLOSED: 0, + CONF_STATE_OPENING: 2, + CONF_STATE_CLOSING: 3, + CONF_STATUS_REGISTER: 1234, + CONF_STATUS_REGISTER_TYPE: CALL_TYPE_REGISTER_HOLDING, + } + mock_restore_cache( + hass, + (State(f"{entity_id}", state),), + ) + await base_config_test( + hass, + config, + cover_name, + COVER_DOMAIN, + CONF_COVERS, + None, + method_discovery=True, + ) + assert hass.states.get(entity_id).state == state