diff --git a/homeassistant/components/binary_sensor/device_trigger.py b/homeassistant/components/binary_sensor/device_trigger.py index b87a761a7a1..0e41db85763 100644 --- a/homeassistant/components/binary_sensor/device_trigger.py +++ b/homeassistant/components/binary_sensor/device_trigger.py @@ -7,8 +7,9 @@ from homeassistant.components.device_automation.const import ( CONF_TURNED_ON, ) from homeassistant.components.homeassistant.triggers import state as state_trigger -from homeassistant.const import ATTR_DEVICE_CLASS, CONF_ENTITY_ID, CONF_FOR, CONF_TYPE +from homeassistant.const import CONF_ENTITY_ID, CONF_FOR, CONF_TYPE from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.entity import get_device_class from homeassistant.helpers.entity_registry import async_entries_for_device from . import ( @@ -220,10 +221,7 @@ async def async_get_triggers(hass, device_id): ] for entry in entries: - device_class = DEVICE_CLASS_NONE - state = hass.states.get(entry.entity_id) - if state: - device_class = state.attributes.get(ATTR_DEVICE_CLASS) + device_class = get_device_class(hass, entry.entity_id) or DEVICE_CLASS_NONE templates = ENTITY_TRIGGERS.get( device_class, ENTITY_TRIGGERS[DEVICE_CLASS_NONE] diff --git a/tests/components/binary_sensor/test_device_trigger.py b/tests/components/binary_sensor/test_device_trigger.py index 0e5cbcc1d70..1dbed7d19e1 100644 --- a/tests/components/binary_sensor/test_device_trigger.py +++ b/tests/components/binary_sensor/test_device_trigger.py @@ -78,6 +78,46 @@ async def test_get_triggers(hass, device_reg, entity_reg, enable_custom_integrat assert triggers == expected_triggers +async def test_get_triggers_no_state( + hass, device_reg, entity_reg, enable_custom_integrations +): + """Test we get the expected triggers from a binary_sensor.""" + platform = getattr(hass.components, f"test.{DOMAIN}") + platform.init() + entity_ids = {} + + config_entry = MockConfigEntry(domain="test", data={}) + config_entry.add_to_hass(hass) + device_entry = device_reg.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + ) + for device_class in DEVICE_CLASSES: + entity_ids[device_class] = entity_reg.async_get_or_create( + DOMAIN, + "test", + platform.ENTITIES[device_class].unique_id, + device_id=device_entry.id, + device_class=device_class, + ).entity_id + + await hass.async_block_till_done() + + expected_triggers = [ + { + "platform": "device", + "domain": DOMAIN, + "type": trigger["type"], + "device_id": device_entry.id, + "entity_id": entity_ids[device_class], + } + for device_class in DEVICE_CLASSES + for trigger in ENTITY_TRIGGERS[device_class] + ] + triggers = await async_get_device_automations(hass, "trigger", device_entry.id) + assert triggers == expected_triggers + + async def test_get_trigger_capabilities(hass, device_reg, entity_reg): """Test we get the expected capabilities from a binary_sensor trigger.""" config_entry = MockConfigEntry(domain="test", data={})