From 99d1e3e71d80fcfcf395536c2d6c1df1c025e423 Mon Sep 17 00:00:00 2001 From: Drzony Date: Mon, 15 Mar 2021 11:24:07 +0100 Subject: [PATCH] MQTT Light: Use flash attribute in async_turn_off (#47919) --- .../components/mqtt/light/schema_json.py | 26 ++++++++++--------- tests/components/light/common.py | 12 ++++++--- tests/components/mqtt/test_light_json.py | 18 +++++++++++++ 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index 99c48aa1c8f..8f2d1dda0a7 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -341,6 +341,18 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): """Flag supported features.""" return self._supported_features + def _set_flash_and_transition(self, message, **kwargs): + if ATTR_TRANSITION in kwargs: + message["transition"] = kwargs[ATTR_TRANSITION] + + if ATTR_FLASH in kwargs: + flash = kwargs.get(ATTR_FLASH) + + if flash == FLASH_LONG: + message["flash"] = self._flash_times[CONF_FLASH_TIME_LONG] + elif flash == FLASH_SHORT: + message["flash"] = self._flash_times[CONF_FLASH_TIME_SHORT] + async def async_turn_on(self, **kwargs): """Turn the device on. @@ -380,16 +392,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): self._hs = kwargs[ATTR_HS_COLOR] should_update = True - if ATTR_FLASH in kwargs: - flash = kwargs.get(ATTR_FLASH) - - if flash == FLASH_LONG: - message["flash"] = self._flash_times[CONF_FLASH_TIME_LONG] - elif flash == FLASH_SHORT: - message["flash"] = self._flash_times[CONF_FLASH_TIME_SHORT] - - if ATTR_TRANSITION in kwargs: - message["transition"] = kwargs[ATTR_TRANSITION] + self._set_flash_and_transition(message, **kwargs) if ATTR_BRIGHTNESS in kwargs and self._config[CONF_BRIGHTNESS]: brightness_normalized = kwargs[ATTR_BRIGHTNESS] / DEFAULT_BRIGHTNESS_SCALE @@ -449,8 +452,7 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): """ message = {"state": "OFF"} - if ATTR_TRANSITION in kwargs: - message["transition"] = kwargs[ATTR_TRANSITION] + self._set_flash_and_transition(message, **kwargs) mqtt.async_publish( self.hass, diff --git a/tests/components/light/common.py b/tests/components/light/common.py index a9991bf3594..20ace3641cd 100644 --- a/tests/components/light/common.py +++ b/tests/components/light/common.py @@ -111,16 +111,20 @@ async def async_turn_on( @bind_hass -def turn_off(hass, entity_id=ENTITY_MATCH_ALL, transition=None): +def turn_off(hass, entity_id=ENTITY_MATCH_ALL, transition=None, flash=None): """Turn all or specified light off.""" - hass.add_job(async_turn_off, hass, entity_id, transition) + hass.add_job(async_turn_off, hass, entity_id, transition, flash) -async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL, transition=None): +async def async_turn_off(hass, entity_id=ENTITY_MATCH_ALL, transition=None, flash=None): """Turn all or specified light off.""" data = { key: value - for key, value in [(ATTR_ENTITY_ID, entity_id), (ATTR_TRANSITION, transition)] + for key, value in [ + (ATTR_ENTITY_ID, entity_id), + (ATTR_TRANSITION, transition), + (ATTR_FLASH, flash), + ] if value is not None } diff --git a/tests/components/mqtt/test_light_json.py b/tests/components/mqtt/test_light_json.py index 022df109f38..bdb81e5e5e4 100644 --- a/tests/components/mqtt/test_light_json.py +++ b/tests/components/mqtt/test_light_json.py @@ -876,6 +876,24 @@ async def test_flash_short_and_long(hass, mqtt_mock): state = hass.states.get("light.test") assert state.state == STATE_ON + await common.async_turn_off(hass, "light.test", flash="short") + + mqtt_mock.async_publish.assert_called_once_with( + "test_light_rgb/set", JsonValidator('{"state": "OFF", "flash": 5}'), 0, False + ) + mqtt_mock.async_publish.reset_mock() + state = hass.states.get("light.test") + assert state.state == STATE_OFF + + await common.async_turn_off(hass, "light.test", flash="long") + + mqtt_mock.async_publish.assert_called_once_with( + "test_light_rgb/set", JsonValidator('{"state": "OFF", "flash": 15}'), 0, False + ) + mqtt_mock.async_publish.reset_mock() + state = hass.states.get("light.test") + assert state.state == STATE_OFF + async def test_transition(hass, mqtt_mock): """Test for transition time being sent when included."""