From e1a0a855d52478de2d1b8eb65a8e22bf2fe54afa Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 14 Aug 2024 11:44:38 +0200 Subject: [PATCH] Support None schema in EntityComponent.async_register_entity_service (#123867) --- homeassistant/helpers/entity_component.py | 4 +-- tests/helpers/test_entity_component.py | 30 +++++++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 0034eb1c6fc..c8bcda0eef2 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -258,13 +258,13 @@ class EntityComponent(Generic[_EntityT]): def async_register_entity_service( self, name: str, - schema: VolDictType | VolSchemaType, + schema: VolDictType | VolSchemaType | None, func: str | Callable[..., Any], required_features: list[int] | None = None, supports_response: SupportsResponse = SupportsResponse.NONE, ) -> None: """Register an entity service.""" - if isinstance(schema, dict): + if schema is None or isinstance(schema, dict): schema = cv.make_entity_service_schema(schema) service_func: str | HassJob[..., Any] diff --git a/tests/helpers/test_entity_component.py b/tests/helpers/test_entity_component.py index 3f34305b39d..0c09c9d75f7 100644 --- a/tests/helpers/test_entity_component.py +++ b/tests/helpers/test_entity_component.py @@ -495,7 +495,19 @@ async def test_extract_all_use_match_all( ) not in caplog.text -async def test_register_entity_service(hass: HomeAssistant) -> None: +@pytest.mark.parametrize( + ("schema", "service_data"), + [ + ({"some": str}, {"some": "data"}), + ({}, {}), + (None, {}), + ], +) +async def test_register_entity_service( + hass: HomeAssistant, + schema: dict | None, + service_data: dict, +) -> None: """Test registering an enttiy service and calling it.""" entity = MockEntity(entity_id=f"{DOMAIN}.entity") calls = [] @@ -510,9 +522,7 @@ async def test_register_entity_service(hass: HomeAssistant) -> None: await component.async_setup({}) await component.async_add_entities([entity]) - component.async_register_entity_service( - "hello", {"some": str}, "async_called_by_service" - ) + component.async_register_entity_service("hello", schema, "async_called_by_service") with pytest.raises(vol.Invalid): await hass.services.async_call( @@ -524,24 +534,24 @@ async def test_register_entity_service(hass: HomeAssistant) -> None: assert len(calls) == 0 await hass.services.async_call( - DOMAIN, "hello", {"entity_id": entity.entity_id, "some": "data"}, blocking=True + DOMAIN, "hello", {"entity_id": entity.entity_id} | service_data, blocking=True ) assert len(calls) == 1 - assert calls[0] == {"some": "data"} + assert calls[0] == service_data await hass.services.async_call( - DOMAIN, "hello", {"entity_id": ENTITY_MATCH_ALL, "some": "data"}, blocking=True + DOMAIN, "hello", {"entity_id": ENTITY_MATCH_ALL} | service_data, blocking=True ) assert len(calls) == 2 - assert calls[1] == {"some": "data"} + assert calls[1] == service_data await hass.services.async_call( - DOMAIN, "hello", {"entity_id": ENTITY_MATCH_NONE, "some": "data"}, blocking=True + DOMAIN, "hello", {"entity_id": ENTITY_MATCH_NONE} | service_data, blocking=True ) assert len(calls) == 2 await hass.services.async_call( - DOMAIN, "hello", {"area_id": ENTITY_MATCH_NONE, "some": "data"}, blocking=True + DOMAIN, "hello", {"area_id": ENTITY_MATCH_NONE} | service_data, blocking=True ) assert len(calls) == 2