Fix service helper not handling sync methods (#31254)

* Fix service helper not handling sync methods

* Add legacy support for returning coroutine objects

* Fix tests

* Fix tests

* Convert demo cover to async
This commit is contained in:
Paulus Schoutsen 2020-01-29 16:27:25 -08:00 committed by GitHub
parent 111fc1fa8e
commit 01dad31adc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 73 additions and 46 deletions

View File

@ -6,7 +6,8 @@ from homeassistant.components.cover import (
SUPPORT_OPEN, SUPPORT_OPEN,
CoverDevice, CoverDevice,
) )
from homeassistant.helpers.event import track_utc_time_change from homeassistant.core import callback
from homeassistant.helpers.event import async_track_utc_time_change
from . import DOMAIN from . import DOMAIN
@ -131,21 +132,21 @@ class DemoCover(CoverDevice):
return self._supported_features return self._supported_features
return super().supported_features return super().supported_features
def close_cover(self, **kwargs): async def async_close_cover(self, **kwargs):
"""Close the cover.""" """Close the cover."""
if self._position == 0: if self._position == 0:
return return
if self._position is None: if self._position is None:
self._closed = True self._closed = True
self.schedule_update_ha_state() self.async_write_ha_state()
return return
self._is_closing = True self._is_closing = True
self._listen_cover() self._listen_cover()
self._requested_closing = True self._requested_closing = True
self.schedule_update_ha_state() self.async_write_ha_state()
def close_cover_tilt(self, **kwargs): async def async_close_cover_tilt(self, **kwargs):
"""Close the cover tilt.""" """Close the cover tilt."""
if self._tilt_position in (0, None): if self._tilt_position in (0, None):
return return
@ -153,21 +154,21 @@ class DemoCover(CoverDevice):
self._listen_cover_tilt() self._listen_cover_tilt()
self._requested_closing_tilt = True self._requested_closing_tilt = True
def open_cover(self, **kwargs): async def async_open_cover(self, **kwargs):
"""Open the cover.""" """Open the cover."""
if self._position == 100: if self._position == 100:
return return
if self._position is None: if self._position is None:
self._closed = False self._closed = False
self.schedule_update_ha_state() self.async_write_ha_state()
return return
self._is_opening = True self._is_opening = True
self._listen_cover() self._listen_cover()
self._requested_closing = False self._requested_closing = False
self.schedule_update_ha_state() self.async_write_ha_state()
def open_cover_tilt(self, **kwargs): async def async_open_cover_tilt(self, **kwargs):
"""Open the cover tilt.""" """Open the cover tilt."""
if self._tilt_position in (100, None): if self._tilt_position in (100, None):
return return
@ -175,7 +176,7 @@ class DemoCover(CoverDevice):
self._listen_cover_tilt() self._listen_cover_tilt()
self._requested_closing_tilt = False self._requested_closing_tilt = False
def set_cover_position(self, **kwargs): async def async_set_cover_position(self, **kwargs):
"""Move the cover to a specific position.""" """Move the cover to a specific position."""
position = kwargs.get(ATTR_POSITION) position = kwargs.get(ATTR_POSITION)
self._set_position = round(position, -1) self._set_position = round(position, -1)
@ -185,7 +186,7 @@ class DemoCover(CoverDevice):
self._listen_cover() self._listen_cover()
self._requested_closing = position < self._position self._requested_closing = position < self._position
def set_cover_tilt_position(self, **kwargs): async def async_set_cover_tilt_position(self, **kwargs):
"""Move the cover til to a specific position.""" """Move the cover til to a specific position."""
tilt_position = kwargs.get(ATTR_TILT_POSITION) tilt_position = kwargs.get(ATTR_TILT_POSITION)
self._set_tilt_position = round(tilt_position, -1) self._set_tilt_position = round(tilt_position, -1)
@ -195,7 +196,7 @@ class DemoCover(CoverDevice):
self._listen_cover_tilt() self._listen_cover_tilt()
self._requested_closing_tilt = tilt_position < self._tilt_position self._requested_closing_tilt = tilt_position < self._tilt_position
def stop_cover(self, **kwargs): async def async_stop_cover(self, **kwargs):
"""Stop the cover.""" """Stop the cover."""
self._is_closing = False self._is_closing = False
self._is_opening = False self._is_opening = False
@ -206,7 +207,7 @@ class DemoCover(CoverDevice):
self._unsub_listener_cover = None self._unsub_listener_cover = None
self._set_position = None self._set_position = None
def stop_cover_tilt(self, **kwargs): async def async_stop_cover_tilt(self, **kwargs):
"""Stop the cover tilt.""" """Stop the cover tilt."""
if self._tilt_position is None: if self._tilt_position is None:
return return
@ -216,14 +217,15 @@ class DemoCover(CoverDevice):
self._unsub_listener_cover_tilt = None self._unsub_listener_cover_tilt = None
self._set_tilt_position = None self._set_tilt_position = None
@callback
def _listen_cover(self): def _listen_cover(self):
"""Listen for changes in cover.""" """Listen for changes in cover."""
if self._unsub_listener_cover is None: if self._unsub_listener_cover is None:
self._unsub_listener_cover = track_utc_time_change( self._unsub_listener_cover = async_track_utc_time_change(
self.hass, self._time_changed_cover self.hass, self._time_changed_cover
) )
def _time_changed_cover(self, now): async def _time_changed_cover(self, now):
"""Track time changes.""" """Track time changes."""
if self._requested_closing: if self._requested_closing:
self._position -= 10 self._position -= 10
@ -231,20 +233,20 @@ class DemoCover(CoverDevice):
self._position += 10 self._position += 10
if self._position in (100, 0, self._set_position): if self._position in (100, 0, self._set_position):
self.stop_cover() await self.async_stop_cover()
self._closed = self.current_cover_position <= 0 self._closed = self.current_cover_position <= 0
self.async_write_ha_state()
self.schedule_update_ha_state() @callback
def _listen_cover_tilt(self): def _listen_cover_tilt(self):
"""Listen for changes in cover tilt.""" """Listen for changes in cover tilt."""
if self._unsub_listener_cover_tilt is None: if self._unsub_listener_cover_tilt is None:
self._unsub_listener_cover_tilt = track_utc_time_change( self._unsub_listener_cover_tilt = async_track_utc_time_change(
self.hass, self._time_changed_cover_tilt self.hass, self._time_changed_cover_tilt
) )
def _time_changed_cover_tilt(self, now): async def _time_changed_cover_tilt(self, now):
"""Track time changes.""" """Track time changes."""
if self._requested_closing_tilt: if self._requested_closing_tilt:
self._tilt_position -= 10 self._tilt_position -= 10
@ -252,6 +254,6 @@ class DemoCover(CoverDevice):
self._tilt_position += 10 self._tilt_position += 10
if self._tilt_position in (100, 0, self._set_tilt_position): if self._tilt_position in (100, 0, self._set_tilt_position):
self.stop_cover_tilt() await self.async_stop_cover_tilt()
self.schedule_update_ha_state() self.async_write_ha_state()

View File

@ -1,6 +1,6 @@
"""Service calling related helpers.""" """Service calling related helpers."""
import asyncio import asyncio
from functools import wraps from functools import partial, wraps
import logging import logging
from typing import Callable from typing import Callable
@ -339,7 +339,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
tasks = [ tasks = [
_handle_service_platform_call( _handle_service_platform_call(
func, data, entities, call.context, required_features hass, func, data, entities, call.context, required_features
) )
for platform, entities in zip(platforms, platforms_entities) for platform, entities in zip(platforms, platforms_entities)
] ]
@ -352,7 +352,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non
async def _handle_service_platform_call( async def _handle_service_platform_call(
func, data, entities, context, required_features hass, func, data, entities, context, required_features
): ):
"""Handle a function call.""" """Handle a function call."""
tasks = [] tasks = []
@ -370,9 +370,17 @@ async def _handle_service_platform_call(
entity.async_set_context(context) entity.async_set_context(context)
if isinstance(func, str): if isinstance(func, str):
await getattr(entity, func)(**data) result = await hass.async_add_job(partial(getattr(entity, func), **data))
else: else:
await func(entity, data) result = await hass.async_add_job(func, entity, data)
if asyncio.iscoroutine(result):
_LOGGER.error(
"Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.",
func,
entity.entity_id,
)
await result
if entity.should_poll: if entity.should_poll:
tasks.append(entity.async_update_ha_state(True)) tasks.append(entity.async_update_ha_state(True))

View File

@ -4,7 +4,6 @@ Test setup of RFLink lights component/platform. State tracking and
control of RFLink switch devices. control of RFLink switch devices.
""" """
from homeassistant.components.light import ATTR_BRIGHTNESS from homeassistant.components.light import ATTR_BRIGHTNESS
from homeassistant.components.rflink import EVENT_BUTTON_PRESSED from homeassistant.components.rflink import EVENT_BUTTON_PRESSED
from homeassistant.const import ( from homeassistant.const import (
@ -267,16 +266,12 @@ async def test_signal_repetitions_alternation(hass, monkeypatch):
# setup mocking rflink module # setup mocking rflink module
_, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch) _, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch)
hass.async_create_task( await hass.services.async_call(
hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"} DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
) )
) await hass.services.async_call(
hass.async_create_task(
hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"} DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test1"}
) )
)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -299,11 +294,9 @@ async def test_signal_repetitions_cancelling(hass, monkeypatch):
# setup mocking rflink module # setup mocking rflink module
_, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch) _, _, protocol, _ = await mock_rflink(hass, config, DOMAIN, monkeypatch)
hass.async_create_task( await hass.services.async_call(
hass.services.async_call(
DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"} DOMAIN, SERVICE_TURN_OFF, {ATTR_ENTITY_ID: DOMAIN + ".test"}
) )
)
hass.async_create_task( hass.async_create_task(
hass.services.async_call( hass.services.async_call(

View File

@ -306,6 +306,30 @@ async def test_call_with_required_features(hass, mock_entities):
assert test_service_mock.call_count == 1 assert test_service_mock.call_count == 1
async def test_call_with_sync_func(hass, mock_entities):
"""Test invoking sync service calls."""
test_service_mock = Mock()
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
test_service_mock,
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
)
assert test_service_mock.call_count == 1
async def test_call_with_sync_attr(hass, mock_entities):
"""Test invoking sync service calls."""
mock_entities["light.kitchen"].sync_method = Mock()
await service.entity_service_call(
hass,
[Mock(entities=mock_entities)],
"sync_method",
ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}),
)
assert mock_entities["light.kitchen"].sync_method.call_count == 1
async def test_call_context_user_not_exist(hass): async def test_call_context_user_not_exist(hass):
"""Check we don't allow deleted users to do things.""" """Check we don't allow deleted users to do things."""
with pytest.raises(exceptions.UnknownUser) as err: with pytest.raises(exceptions.UnknownUser) as err:
@ -348,7 +372,7 @@ async def test_call_context_target_all(hass, mock_service_platform_call, mock_en
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2] entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]] assert entities == [mock_entities["light.kitchen"]]
@ -379,7 +403,7 @@ async def test_call_context_target_specific(
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2] entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]] assert entities == [mock_entities["light.kitchen"]]
@ -422,7 +446,7 @@ async def test_call_no_context_target_all(
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2] entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == list(mock_entities.values()) assert entities == list(mock_entities.values())
@ -442,7 +466,7 @@ async def test_call_no_context_target_specific(
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2] entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [mock_entities["light.kitchen"]] assert entities == [mock_entities["light.kitchen"]]
@ -458,7 +482,7 @@ async def test_call_with_match_all(
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2] entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [ assert entities == [
mock_entities["light.kitchen"], mock_entities["light.kitchen"],
mock_entities["light.living_room"], mock_entities["light.living_room"],
@ -480,7 +504,7 @@ async def test_call_with_omit_entity_id(
) )
assert len(mock_service_platform_call.mock_calls) == 1 assert len(mock_service_platform_call.mock_calls) == 1
entities = mock_service_platform_call.mock_calls[0][1][2] entities = mock_service_platform_call.mock_calls[0][1][3]
assert entities == [] assert entities == []