diff --git a/homeassistant/components/demo/cover.py b/homeassistant/components/demo/cover.py index 20e3a52aa8d..ab95cc978b3 100644 --- a/homeassistant/components/demo/cover.py +++ b/homeassistant/components/demo/cover.py @@ -6,7 +6,8 @@ from homeassistant.components.cover import ( SUPPORT_OPEN, CoverDevice, ) -from homeassistant.helpers.event import track_utc_time_change +from homeassistant.core import callback +from homeassistant.helpers.event import async_track_utc_time_change from . import DOMAIN @@ -131,21 +132,21 @@ class DemoCover(CoverDevice): return self._supported_features return super().supported_features - def close_cover(self, **kwargs): + async def async_close_cover(self, **kwargs): """Close the cover.""" if self._position == 0: return if self._position is None: self._closed = True - self.schedule_update_ha_state() + self.async_write_ha_state() return self._is_closing = True self._listen_cover() self._requested_closing = True - self.schedule_update_ha_state() + self.async_write_ha_state() - def close_cover_tilt(self, **kwargs): + async def async_close_cover_tilt(self, **kwargs): """Close the cover tilt.""" if self._tilt_position in (0, None): return @@ -153,21 +154,21 @@ class DemoCover(CoverDevice): self._listen_cover_tilt() self._requested_closing_tilt = True - def open_cover(self, **kwargs): + async def async_open_cover(self, **kwargs): """Open the cover.""" if self._position == 100: return if self._position is None: self._closed = False - self.schedule_update_ha_state() + self.async_write_ha_state() return self._is_opening = True self._listen_cover() self._requested_closing = False - self.schedule_update_ha_state() + self.async_write_ha_state() - def open_cover_tilt(self, **kwargs): + async def async_open_cover_tilt(self, **kwargs): """Open the cover tilt.""" if self._tilt_position in (100, None): return @@ -175,7 +176,7 @@ class DemoCover(CoverDevice): self._listen_cover_tilt() self._requested_closing_tilt = False - def set_cover_position(self, **kwargs): + async def async_set_cover_position(self, **kwargs): """Move the cover to a specific position.""" position = kwargs.get(ATTR_POSITION) self._set_position = round(position, -1) @@ -185,7 +186,7 @@ class DemoCover(CoverDevice): self._listen_cover() self._requested_closing = position < self._position - def set_cover_tilt_position(self, **kwargs): + async def async_set_cover_tilt_position(self, **kwargs): """Move the cover til to a specific position.""" tilt_position = kwargs.get(ATTR_TILT_POSITION) self._set_tilt_position = round(tilt_position, -1) @@ -195,7 +196,7 @@ class DemoCover(CoverDevice): self._listen_cover_tilt() self._requested_closing_tilt = tilt_position < self._tilt_position - def stop_cover(self, **kwargs): + async def async_stop_cover(self, **kwargs): """Stop the cover.""" self._is_closing = False self._is_opening = False @@ -206,7 +207,7 @@ class DemoCover(CoverDevice): self._unsub_listener_cover = None self._set_position = None - def stop_cover_tilt(self, **kwargs): + async def async_stop_cover_tilt(self, **kwargs): """Stop the cover tilt.""" if self._tilt_position is None: return @@ -216,14 +217,15 @@ class DemoCover(CoverDevice): self._unsub_listener_cover_tilt = None self._set_tilt_position = None + @callback def _listen_cover(self): """Listen for changes in cover.""" if self._unsub_listener_cover is None: - self._unsub_listener_cover = track_utc_time_change( + self._unsub_listener_cover = async_track_utc_time_change( self.hass, self._time_changed_cover ) - def _time_changed_cover(self, now): + async def _time_changed_cover(self, now): """Track time changes.""" if self._requested_closing: self._position -= 10 @@ -231,20 +233,20 @@ class DemoCover(CoverDevice): self._position += 10 if self._position in (100, 0, self._set_position): - self.stop_cover() + await self.async_stop_cover() self._closed = self.current_cover_position <= 0 + self.async_write_ha_state() - self.schedule_update_ha_state() - + @callback def _listen_cover_tilt(self): """Listen for changes in cover tilt.""" if self._unsub_listener_cover_tilt is None: - self._unsub_listener_cover_tilt = track_utc_time_change( + self._unsub_listener_cover_tilt = async_track_utc_time_change( self.hass, self._time_changed_cover_tilt ) - def _time_changed_cover_tilt(self, now): + async def _time_changed_cover_tilt(self, now): """Track time changes.""" if self._requested_closing_tilt: self._tilt_position -= 10 @@ -252,6 +254,6 @@ class DemoCover(CoverDevice): self._tilt_position += 10 if self._tilt_position in (100, 0, self._set_tilt_position): - self.stop_cover_tilt() + await self.async_stop_cover_tilt() - self.schedule_update_ha_state() + self.async_write_ha_state() diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index d621d4e6242..89c2715a760 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -1,6 +1,6 @@ """Service calling related helpers.""" import asyncio -from functools import wraps +from functools import partial, wraps import logging from typing import Callable @@ -339,7 +339,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non tasks = [ _handle_service_platform_call( - func, data, entities, call.context, required_features + hass, func, data, entities, call.context, required_features ) for platform, entities in zip(platforms, platforms_entities) ] @@ -352,7 +352,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non async def _handle_service_platform_call( - func, data, entities, context, required_features + hass, func, data, entities, context, required_features ): """Handle a function call.""" tasks = [] @@ -370,9 +370,17 @@ async def _handle_service_platform_call( entity.async_set_context(context) if isinstance(func, str): - await getattr(entity, func)(**data) + result = await hass.async_add_job(partial(getattr(entity, func), **data)) else: - await func(entity, data) + result = await hass.async_add_job(func, entity, data) + + if asyncio.iscoroutine(result): + _LOGGER.error( + "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.", + func, + entity.entity_id, + ) + await result if entity.should_poll: tasks.append(entity.async_update_ha_state(True)) diff --git a/tests/components/rflink/test_light.py b/tests/components/rflink/test_light.py index b22730a3310..970c532f22e 100644 --- a/tests/components/rflink/test_light.py +++ b/tests/components/rflink/test_light.py @@ -4,7 +4,6 @@ Test setup of RFLink lights component/platform. State tracking and control of RFLink switch devices. """ - from homeassistant.components.light import ATTR_BRIGHTNESS from homeassistant.components.rflink import EVENT_BUTTON_PRESSED from homeassistant.const import ( @@ -267,15 +266,11 @@ async def test_signal_repetitions_alternation(hass, monkeypatch): # setup mocking rflink module _, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch) - hass.async_create_task( - hass.services.async_call( - DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"} - ) + await hass.services.async_call( + DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"} ) - hass.async_create_task( - hass.services.async_call( - DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"} - ) + await hass.services.async_call( + DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"} ) await hass.async_block_till_done() @@ -299,10 +294,8 @@ async def test_signal_repetitions_cancelling(hass, monkeypatch): # setup mocking rflink module _, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch) - hass.async_create_task( - hass.services.async_call( - DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"} - ) + await hass.services.async_call( + DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"} ) hass.async_create_task( diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index c80b6eac193..8d28bc73b88 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -306,6 +306,30 @@ async def test_call_with_required_features(hass, mock_entities): assert test_service_mock.call_count == 1 +async def test_call_with_sync_func(hass, mock_entities): + """Test invoking sync service calls.""" + test_service_mock = Mock() + await service.entity_service_call( + hass, + [Mock(entities=mock_entities)], + test_service_mock, + ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), + ) + assert test_service_mock.call_count == 1 + + +async def test_call_with_sync_attr(hass, mock_entities): + """Test invoking sync service calls.""" + mock_entities["light.kitchen"].sync_method = Mock() + await service.entity_service_call( + hass, + [Mock(entities=mock_entities)], + "sync_method", + ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), + ) + assert mock_entities["light.kitchen"].sync_method.call_count == 1 + + async def test_call_context_user_not_exist(hass): """Check we don't allow deleted users to do things.""" with pytest.raises(exceptions.UnknownUser) as err: @@ -348,7 +372,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en ) assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][2] + entities = mock_service_platform_call.mock_calls[0][1][3] assert entities == [mock_entities["light.kitchen"]] @@ -379,7 +403,7 @@ async def test_call_context_target_specific( ) assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][2] + entities = mock_service_platform_call.mock_calls[0][1][3] assert entities == [mock_entities["light.kitchen"]] @@ -422,7 +446,7 @@ async def test_call_no_context_target_all( ) assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][2] + entities = mock_service_platform_call.mock_calls[0][1][3] assert entities == list(mock_entities.values()) @@ -442,7 +466,7 @@ async def test_call_no_context_target_specific( ) assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][2] + entities = mock_service_platform_call.mock_calls[0][1][3] assert entities == [mock_entities["light.kitchen"]] @@ -458,7 +482,7 @@ async def test_call_with_match_all( ) assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][2] + entities = mock_service_platform_call.mock_calls[0][1][3] assert entities == [ mock_entities["light.kitchen"], mock_entities["light.living_room"], @@ -480,7 +504,7 @@ async def test_call_with_omit_entity_id( ) assert len(mock_service_platform_call.mock_calls) == 1 - entities = mock_service_platform_call.mock_calls[0][1][2] + entities = mock_service_platform_call.mock_calls[0][1][3] assert entities == []