From 5b7a65c5eaff17a6a4eab1a4a822f995157bdb0d Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 2 Feb 2020 15:36:39 -0800 Subject: [PATCH] Fix service annotations (#31402) * Fix service annotations * Filter area_id from service data * Fix services not accepting entities * Typo --- .../components/input_select/__init__.py | 17 +++-- .../components/media_player/__init__.py | 65 ++++++++++++++----- homeassistant/helpers/config_validation.py | 4 +- homeassistant/helpers/service.py | 9 ++- tests/helpers/test_service.py | 12 +++- 5 files changed, 78 insertions(+), 29 deletions(-) diff --git a/homeassistant/components/input_select/__init__.py b/homeassistant/components/input_select/__init__.py index 26a07e600f3..6044375d8a8 100644 --- a/homeassistant/components/input_select/__init__.py +++ b/homeassistant/components/input_select/__init__.py @@ -143,11 +143,15 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: ) component.async_register_entity_service( - SERVICE_SELECT_NEXT, {}, lambda entity, call: entity.async_offset_index(1) + SERVICE_SELECT_NEXT, + {}, + callback(lambda entity, call: entity.async_offset_index(1)), ) component.async_register_entity_service( - SERVICE_SELECT_PREVIOUS, {}, lambda entity, call: entity.async_offset_index(-1) + SERVICE_SELECT_PREVIOUS, + {}, + callback(lambda entity, call: entity.async_offset_index(-1)), ) component.async_register_entity_service( @@ -248,7 +252,8 @@ class InputSelect(RestoreEntity): """Return unique id for the entity.""" return self._config[CONF_ID] - async def async_select_option(self, option): + @callback + def async_select_option(self, option): """Select new option.""" if option not in self._options: _LOGGER.warning( @@ -260,14 +265,16 @@ class InputSelect(RestoreEntity): self._current_option = option self.async_write_ha_state() - async def async_offset_index(self, offset): + @callback + def async_offset_index(self, offset): """Offset current index.""" current_index = self._options.index(self._current_option) new_index = (current_index + offset) % len(self._options) self._current_option = self._options[new_index] self.async_write_ha_state() - async def async_set_options(self, options): + @callback + def async_set_options(self, options): """Set options.""" self._current_option = options[0] self._config[CONF_OPTIONS] = options diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py index 28951df545a..2911a143a3c 100644 --- a/homeassistant/components/media_player/__init__.py +++ b/homeassistant/components/media_player/__init__.py @@ -173,6 +173,23 @@ SCHEMA_WEBSOCKET_GET_THUMBNAIL = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.exten ) +def _rename_keys(**keys): + """Create validator that renames keys. + + Necessary because the service schema names do not match the command parameters. + + Async friendly. + """ + + def rename(value): + for to_key, from_key in keys.items(): + if from_key in value: + value[to_key] = value.pop(from_key) + return value + + return rename + + async def async_setup(hass, config): """Track states and offer events for media_players.""" component = hass.data[DOMAIN] = EntityComponent( @@ -238,30 +255,39 @@ async def async_setup(hass, config): ) component.async_register_entity_service( SERVICE_VOLUME_SET, - {vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float}, - lambda entity, call: entity.async_set_volume_level( - volume=call.data[ATTR_MEDIA_VOLUME_LEVEL] + vol.All( + cv.make_entity_service_schema( + {vol.Required(ATTR_MEDIA_VOLUME_LEVEL): cv.small_float} + ), + _rename_keys(volume=ATTR_MEDIA_VOLUME_LEVEL), ), + "async_set_volume_level", [SUPPORT_VOLUME_SET], ) component.async_register_entity_service( SERVICE_VOLUME_MUTE, - {vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean}, - lambda entity, call: entity.async_mute_volume( - mute=call.data[ATTR_MEDIA_VOLUME_MUTED] + vol.All( + cv.make_entity_service_schema( + {vol.Required(ATTR_MEDIA_VOLUME_MUTED): cv.boolean} + ), + _rename_keys(mute=ATTR_MEDIA_VOLUME_MUTED), ), + "async_mute_volume", [SUPPORT_VOLUME_MUTE], ) component.async_register_entity_service( SERVICE_MEDIA_SEEK, - { - vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All( - vol.Coerce(float), vol.Range(min=0) - ) - }, - lambda entity, call: entity.async_media_seek( - position=call.data[ATTR_MEDIA_SEEK_POSITION] + vol.All( + cv.make_entity_service_schema( + { + vol.Required(ATTR_MEDIA_SEEK_POSITION): vol.All( + vol.Coerce(float), vol.Range(min=0) + ) + } + ), + _rename_keys(position=ATTR_MEDIA_SEEK_POSITION), ), + "async_media_seek", [SUPPORT_SEEK], ) component.async_register_entity_service( @@ -278,12 +304,15 @@ async def async_setup(hass, config): ) component.async_register_entity_service( SERVICE_PLAY_MEDIA, - MEDIA_PLAYER_PLAY_MEDIA_SCHEMA, - lambda entity, call: entity.async_play_media( - media_type=call.data[ATTR_MEDIA_CONTENT_TYPE], - media_id=call.data[ATTR_MEDIA_CONTENT_ID], - enqueue=call.data.get(ATTR_MEDIA_ENQUEUE), + vol.All( + cv.make_entity_service_schema(MEDIA_PLAYER_PLAY_MEDIA_SCHEMA), + _rename_keys( + media_type=ATTR_MEDIA_CONTENT_TYPE, + media_id=ATTR_MEDIA_CONTENT_ID, + enqueue=ATTR_MEDIA_ENQUEUE, + ), ), + "async_play_media", [SUPPORT_PLAY_MEDIA], ) component.async_register_entity_service( diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index e357a2ba622..852948220de 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -724,6 +724,8 @@ PLATFORM_SCHEMA = vol.Schema( PLATFORM_SCHEMA_BASE = PLATFORM_SCHEMA.extend({}, extra=vol.ALLOW_EXTRA) +ENTITY_SERVICE_FIELDS = (ATTR_ENTITY_ID, ATTR_AREA_ID) + def make_entity_service_schema( schema: dict, *, extra: int = vol.PREVENT_EXTRA @@ -738,7 +740,7 @@ def make_entity_service_schema( }, extra=extra, ), - has_at_least_one_key(ATTR_ENTITY_ID, ATTR_AREA_ID), + has_at_least_one_key(*ENTITY_SERVICE_FIELDS), ) diff --git a/homeassistant/helpers/service.py b/homeassistant/helpers/service.py index 36bfd9c8cb0..b30cab3fbd4 100644 --- a/homeassistant/helpers/service.py +++ b/homeassistant/helpers/service.py @@ -283,7 +283,11 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non # If the service function is a string, we'll pass it the service call data if isinstance(func, str): - data = {key: val for key, val in call.data.items() if key != ATTR_ENTITY_ID} + data = { + key: val + for key, val in call.data.items() + if key not in cv.ENTITY_SERVICE_FIELDS + } # If the service function is not a string, we pass the service call else: data = call @@ -323,6 +327,7 @@ async def entity_service_call(hass, platforms, func, call, required_features=Non for platform in platforms: platform_entities = [] for entity in platform.entities.values(): + if entity.entity_id not in entity_ids: continue @@ -380,7 +385,7 @@ async def _handle_service_platform_call( if asyncio.iscoroutine(result): _LOGGER.error( - "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to component author.", + "Service %s for %s incorrectly returns a coroutine object. Await result instead in service handler. Report bug to integration author.", func, entity.entity_id, ) diff --git a/tests/helpers/test_service.py b/tests/helpers/test_service.py index 8d28bc73b88..d90842d1b71 100644 --- a/tests/helpers/test_service.py +++ b/tests/helpers/test_service.py @@ -320,14 +320,20 @@ async def test_call_with_sync_func(hass, mock_entities): async def test_call_with_sync_attr(hass, mock_entities): """Test invoking sync service calls.""" - mock_entities["light.kitchen"].sync_method = Mock() + mock_method = mock_entities["light.kitchen"].sync_method = Mock() await service.entity_service_call( hass, [Mock(entities=mock_entities)], "sync_method", - ha.ServiceCall("test_domain", "test_service", {"entity_id": "light.kitchen"}), + ha.ServiceCall( + "test_domain", + "test_service", + {"entity_id": "light.kitchen", "area_id": "abcd"}, + ), ) - assert mock_entities["light.kitchen"].sync_method.call_count == 1 + assert mock_method.call_count == 1 + # We pass empty kwargs because both entity_id and area_id are filtered out + assert mock_method.mock_calls[0][2] == {} async def test_call_context_user_not_exist(hass):