diff --git a/homeassistant/components/device_automation/__init__.py b/homeassistant/components/device_automation/__init__.py index 9d80ce169a9..67ef17dc379 100644 --- a/homeassistant/components/device_automation/__init__.py +++ b/homeassistant/components/device_automation/__init__.py @@ -228,7 +228,11 @@ async def _async_get_device_automations( return combined_results -async def _async_get_device_automation_capabilities(hass, automation_type, automation): +async def _async_get_device_automation_capabilities( + hass: HomeAssistant, + automation_type: DeviceAutomationType, + automation: Mapping[str, Any], +) -> dict[str, Any]: """List device automations.""" try: platform = await async_get_device_automation_platform( @@ -237,8 +241,6 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom except InvalidDeviceAutomationConfig: return {} - if isinstance(automation_type, str): # until tests pass DeviceAutomationType - automation_type = DeviceAutomationType[automation_type.upper()] function_name = automation_type.value.get_capabilities_func if not hasattr(platform, function_name): @@ -259,7 +261,7 @@ async def _async_get_device_automation_capabilities(hass, automation_type, autom extra_fields, custom_serializer=cv.custom_serializer ) - return capabilities + return capabilities # type: ignore[no-any-return] def handle_device_errors(func): diff --git a/tests/common.py b/tests/common.py index 327427eda6e..73b67a63ebc 100644 --- a/tests/common.py +++ b/tests/common.py @@ -70,7 +70,7 @@ CLIENT_REDIRECT_URI = "https://example.com/app/callback" async def async_get_device_automations( hass: HomeAssistant, - automation_type: device_automation.DeviceAutomationType | str, + automation_type: device_automation.DeviceAutomationType, device_id: str, ) -> Any: """Get a device automation for a single device id.""" diff --git a/tests/components/alarm_control_panel/test_device_action.py b/tests/components/alarm_control_panel/test_device_action.py index f4b0832ad97..5cbe9f256ba 100644 --- a/tests/components/alarm_control_panel/test_device_action.py +++ b/tests/components/alarm_control_panel/test_device_action.py @@ -170,7 +170,7 @@ async def test_get_action_capabilities( assert {action["type"] for action in actions} == set(expected_capabilities) for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) assert capabilities == expected_capabilities[action["type"]] @@ -222,7 +222,7 @@ async def test_get_action_capabilities_arm_code( assert {action["type"] for action in actions} == set(expected_capabilities) for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) assert capabilities == expected_capabilities[action["type"]] diff --git a/tests/components/alarm_control_panel/test_device_trigger.py b/tests/components/alarm_control_panel/test_device_trigger.py index e874b50baa0..c8082e415e0 100644 --- a/tests/components/alarm_control_panel/test_device_trigger.py +++ b/tests/components/alarm_control_panel/test_device_trigger.py @@ -151,7 +151,7 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): assert len(triggers) == 6 for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == { "extra_fields": [ diff --git a/tests/components/binary_sensor/test_device_condition.py b/tests/components/binary_sensor/test_device_condition.py index 2e23fd799f2..fcef3cfd418 100644 --- a/tests/components/binary_sensor/test_device_condition.py +++ b/tests/components/binary_sensor/test_device_condition.py @@ -137,7 +137,7 @@ async def test_get_condition_capabilities(hass, device_reg, entity_reg): ) for condition in conditions: capabilities = await async_get_device_automation_capabilities( - hass, "condition", condition + hass, DeviceAutomationType.CONDITION, condition ) assert capabilities == expected_capabilities diff --git a/tests/components/binary_sensor/test_device_trigger.py b/tests/components/binary_sensor/test_device_trigger.py index 001af0b1e64..c4cd7df9d91 100644 --- a/tests/components/binary_sensor/test_device_trigger.py +++ b/tests/components/binary_sensor/test_device_trigger.py @@ -140,7 +140,7 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): ) for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == expected_capabilities diff --git a/tests/components/cover/test_device_action.py b/tests/components/cover/test_device_action.py index 7954b3389e1..e1595089d2e 100644 --- a/tests/components/cover/test_device_action.py +++ b/tests/components/cover/test_device_action.py @@ -151,7 +151,7 @@ async def test_get_action_capabilities( assert action_types == {"open", "close", "stop", "open_tilt", "close_tilt"} for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) assert capabilities == {"extra_fields": []} @@ -197,7 +197,7 @@ async def test_get_action_capabilities_set_pos( assert action_types == {"set_position"} for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) if action["type"] == "set_position": assert capabilities == expected_capabilities @@ -246,7 +246,7 @@ async def test_get_action_capabilities_set_tilt_pos( assert action_types == {"open", "close", "set_tilt_position"} for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) if action["type"] == "set_tilt_position": assert capabilities == expected_capabilities diff --git a/tests/components/cover/test_device_condition.py b/tests/components/cover/test_device_condition.py index b964ad8bc0d..8d6403f8b52 100644 --- a/tests/components/cover/test_device_condition.py +++ b/tests/components/cover/test_device_condition.py @@ -138,7 +138,7 @@ async def test_get_condition_capabilities( assert len(conditions) == 4 for condition in conditions: capabilities = await async_get_device_automation_capabilities( - hass, "condition", condition + hass, DeviceAutomationType.CONDITION, condition ) assert capabilities == {"extra_fields": []} @@ -189,7 +189,7 @@ async def test_get_condition_capabilities_set_pos( assert len(conditions) == 5 for condition in conditions: capabilities = await async_get_device_automation_capabilities( - hass, "condition", condition + hass, DeviceAutomationType.CONDITION, condition ) if condition["type"] == "is_position": assert capabilities == expected_capabilities @@ -243,7 +243,7 @@ async def test_get_condition_capabilities_set_tilt_pos( assert len(conditions) == 5 for condition in conditions: capabilities = await async_get_device_automation_capabilities( - hass, "condition", condition + hass, DeviceAutomationType.CONDITION, condition ) if condition["type"] == "is_tilt_position": assert capabilities == expected_capabilities diff --git a/tests/components/cover/test_device_trigger.py b/tests/components/cover/test_device_trigger.py index 323394e9fe3..3eac5d29b61 100644 --- a/tests/components/cover/test_device_trigger.py +++ b/tests/components/cover/test_device_trigger.py @@ -158,7 +158,7 @@ async def test_get_trigger_capabilities( assert len(triggers) == 4 for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == { "extra_fields": [ @@ -213,7 +213,7 @@ async def test_get_trigger_capabilities_set_pos( assert len(triggers) == 5 for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) if trigger["type"] == "position": assert capabilities == expected_capabilities @@ -275,7 +275,7 @@ async def test_get_trigger_capabilities_set_tilt_pos( assert len(triggers) == 5 for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) if trigger["type"] == "tilt_position": assert capabilities == expected_capabilities diff --git a/tests/components/fan/test_device_trigger.py b/tests/components/fan/test_device_trigger.py index c58b9004cde..0d9edaf6fab 100644 --- a/tests/components/fan/test_device_trigger.py +++ b/tests/components/fan/test_device_trigger.py @@ -92,7 +92,7 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): ) for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == expected_capabilities diff --git a/tests/components/light/test_device_action.py b/tests/components/light/test_device_action.py index 6d38b4784a7..ec47710a6f6 100644 --- a/tests/components/light/test_device_action.py +++ b/tests/components/light/test_device_action.py @@ -127,7 +127,7 @@ async def test_get_action_capabilities(hass, device_reg, entity_reg): assert action_types == {"turn_on", "toggle", "turn_off"} for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) assert capabilities == {"extra_fields": []} @@ -135,7 +135,7 @@ async def test_get_action_capabilities(hass, device_reg, entity_reg): entity_reg.async_remove(entity_id) for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) assert capabilities == {"extra_fields": []} @@ -273,7 +273,7 @@ async def test_get_action_capabilities_features( assert action_types == expected_actions for action in actions: capabilities = await async_get_device_automation_capabilities( - hass, "action", action + hass, DeviceAutomationType.ACTION, action ) expected = {"extra_fields": expected_capabilities.get(action["type"], [])} assert capabilities == expected diff --git a/tests/components/light/test_device_condition.py b/tests/components/light/test_device_condition.py index 7afd1272017..ba718b385fb 100644 --- a/tests/components/light/test_device_condition.py +++ b/tests/components/light/test_device_condition.py @@ -91,7 +91,7 @@ async def test_get_condition_capabilities(hass, device_reg, entity_reg): ) for condition in conditions: capabilities = await async_get_device_automation_capabilities( - hass, "condition", condition + hass, DeviceAutomationType.CONDITION, condition ) assert capabilities == expected_capabilities diff --git a/tests/components/light/test_device_trigger.py b/tests/components/light/test_device_trigger.py index 342a761e7c4..d6e906abd74 100644 --- a/tests/components/light/test_device_trigger.py +++ b/tests/components/light/test_device_trigger.py @@ -91,7 +91,7 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): ) for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == expected_capabilities diff --git a/tests/components/lock/test_device_trigger.py b/tests/components/lock/test_device_trigger.py index 000ed8b44aa..cf4287a02be 100644 --- a/tests/components/lock/test_device_trigger.py +++ b/tests/components/lock/test_device_trigger.py @@ -116,7 +116,7 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): assert len(triggers) == 5 for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == { "extra_fields": [ diff --git a/tests/components/media_player/test_device_trigger.py b/tests/components/media_player/test_device_trigger.py index 6440430e2d2..842184d62ef 100644 --- a/tests/components/media_player/test_device_trigger.py +++ b/tests/components/media_player/test_device_trigger.py @@ -91,7 +91,7 @@ async def test_get_trigger_capabilities(hass, device_reg, entity_reg): assert len(triggers) == 5 for trigger in triggers: capabilities = await async_get_device_automation_capabilities( - hass, "trigger", trigger + hass, DeviceAutomationType.TRIGGER, trigger ) assert capabilities == { "extra_fields": [