diff --git a/homeassistant/components/modbus/binary_sensor.py b/homeassistant/components/modbus/binary_sensor.py index 045447f7246..c27fde6d946 100644 --- a/homeassistant/components/modbus/binary_sensor.py +++ b/homeassistant/components/modbus/binary_sensor.py @@ -17,9 +17,11 @@ from homeassistant.const import ( CONF_NAME, CONF_SCAN_INTERVAL, CONF_SLAVE, + STATE_ON, ) from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .base_platform import BasePlatform @@ -96,12 +98,17 @@ async def async_setup_platform( async_add_entities(sensors) -class ModbusBinarySensor(BasePlatform, BinarySensorEntity): +class ModbusBinarySensor(BasePlatform, RestoreEntity, BinarySensorEntity): """Modbus binary sensor.""" async def async_added_to_hass(self): """Handle entity which will be added.""" await self.async_base_added_to_hass() + state = await self.async_get_last_state() + if state: + self._value = state.state == STATE_ON + else: + self._value = None @property def is_on(self): diff --git a/tests/components/modbus/test_modbus_binary_sensor.py b/tests/components/modbus/test_binary_sensor.py similarity index 81% rename from tests/components/modbus/test_modbus_binary_sensor.py rename to tests/components/modbus/test_binary_sensor.py index 27821c170e1..5089d0271dd 100644 --- a/tests/components/modbus/test_modbus_binary_sensor.py +++ b/tests/components/modbus/test_binary_sensor.py @@ -18,9 +18,12 @@ from homeassistant.const import ( STATE_ON, STATE_UNAVAILABLE, ) +from homeassistant.core import State from .conftest import ReadResult, base_config_test, base_test, prepare_service_update +from tests.common import mock_restore_cache + @pytest.mark.parametrize("do_discovery", [False, True]) @pytest.mark.parametrize( @@ -130,3 +133,26 @@ async def test_service_binary_sensor_update(hass, mock_pymodbus): "homeassistant", "update_entity", {"entity_id": entity_id}, blocking=True ) assert hass.states.get(entity_id).state == STATE_ON + + +async def test_restore_state_binary_sensor(hass): + """Run test for binary sensor restore state.""" + + sensor_name = "test_binary_sensor" + test_value = STATE_ON + config_sensor = {CONF_NAME: sensor_name, CONF_ADDRESS: 17} + mock_restore_cache( + hass, + (State(f"{SENSOR_DOMAIN}.{sensor_name}", test_value),), + ) + await base_config_test( + hass, + config_sensor, + sensor_name, + SENSOR_DOMAIN, + CONF_BINARY_SENSORS, + None, + method_discovery=True, + ) + entity_id = f"{SENSOR_DOMAIN}.{sensor_name}" + assert hass.states.get(entity_id).state == test_value