From 65ab8cbc717c6481707c2141f56a1a9bf34a5dd3 Mon Sep 17 00:00:00 2001 From: Franck Nijhof Date: Sun, 29 Nov 2020 00:06:32 +0100 Subject: [PATCH] Add support for multiple tags and devices in tag trigger (#43098) Co-authored-by: Paulus Schoutsen --- homeassistant/components/tag/trigger.py | 45 +++++++++++------- tests/components/tag/test_trigger.py | 63 ++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/tag/trigger.py b/homeassistant/components/tag/trigger.py index 8da9baa5aaa..9803bd56afe 100644 --- a/homeassistant/components/tag/trigger.py +++ b/homeassistant/components/tag/trigger.py @@ -1,8 +1,8 @@ """Support for tag triggers.""" import voluptuous as vol -from homeassistant.components.homeassistant.triggers import event as event_trigger from homeassistant.const import CONF_PLATFORM +from homeassistant.core import HassJob from homeassistant.helpers import config_validation as cv from .const import DEVICE_ID, DOMAIN, EVENT_TAG_SCANNED, TAG_ID @@ -10,28 +10,39 @@ from .const import DEVICE_ID, DOMAIN, EVENT_TAG_SCANNED, TAG_ID TRIGGER_SCHEMA = vol.Schema( { vol.Required(CONF_PLATFORM): DOMAIN, - vol.Required(TAG_ID): cv.string, - vol.Optional(DEVICE_ID): cv.string, + vol.Required(TAG_ID): vol.All(cv.ensure_list, [cv.string]), + vol.Optional(DEVICE_ID): vol.All(cv.ensure_list, [cv.string]), } ) async def async_attach_trigger(hass, config, action, automation_info): """Listen for tag_scanned events based on configuration.""" - tag_id = config.get(TAG_ID) - device_id = config.get(DEVICE_ID) - event_data = {TAG_ID: tag_id} + tag_ids = set(config[TAG_ID]) + device_ids = set(config[DEVICE_ID]) if DEVICE_ID in config else None - if device_id: - event_data[DEVICE_ID] = device_id + job = HassJob(action) - event_config = { - event_trigger.CONF_PLATFORM: "event", - event_trigger.CONF_EVENT_TYPE: EVENT_TAG_SCANNED, - event_trigger.CONF_EVENT_DATA: event_data, - } - event_config = event_trigger.TRIGGER_SCHEMA(event_config) + async def handle_event(event): + """Listen for tag scan events and calls the action when data matches.""" + if event.data.get(TAG_ID) not in tag_ids or ( + device_ids is not None and event.data.get(DEVICE_ID) not in device_ids + ): + return - return await event_trigger.async_attach_trigger( - hass, event_config, action, automation_info, platform_type=DOMAIN - ) + task = hass.async_run_hass_job( + job, + { + "trigger": { + "platform": DOMAIN, + "event": event, + "description": "Tag scanned", + } + }, + event.context, + ) + + if task: + await task + + return hass.bus.async_listen(EVENT_TAG_SCANNED, handle_event) diff --git a/tests/components/tag/test_trigger.py b/tests/components/tag/test_trigger.py index 3a83c8e5d2b..9a97d95e7d5 100644 --- a/tests/components/tag/test_trigger.py +++ b/tests/components/tag/test_trigger.py @@ -4,7 +4,8 @@ import pytest import homeassistant.components.automation as automation from homeassistant.components.tag import async_scan_tag -from homeassistant.components.tag.const import DOMAIN, TAG_ID +from homeassistant.components.tag.const import DEVICE_ID, DOMAIN, TAG_ID +from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_OFF from homeassistant.setup import async_setup_component from tests.common import async_mock_service @@ -45,6 +46,7 @@ async def test_triggers(hass, tag_setup, calls): { automation.DOMAIN: [ { + "alias": "test", "trigger": {"platform": DOMAIN, TAG_ID: "abc123"}, "action": { "service": "test.automation", @@ -63,6 +65,18 @@ async def test_triggers(hass, tag_setup, calls): assert len(calls) == 1 assert calls[0].data["message"] == "service called" + await hass.services.async_call( + automation.DOMAIN, + SERVICE_TURN_OFF, + {ATTR_ENTITY_ID: "automation.test"}, + blocking=True, + ) + + await async_scan_tag(hass, "abc123", None) + await hass.async_block_till_done() + + assert len(calls) == 1 + async def test_exception_bad_trigger(hass, calls, caplog): """Test for exception on event triggers firing.""" @@ -84,3 +98,50 @@ async def test_exception_bad_trigger(hass, calls, caplog): ) await hass.async_block_till_done() assert "Invalid config for [automation]" in caplog.text + + +async def test_multiple_tags_and_devices_trigger(hass, tag_setup, calls): + """Test multiple tags and devices triggers.""" + assert await tag_setup() + assert await async_setup_component( + hass, + automation.DOMAIN, + { + automation.DOMAIN: [ + { + "trigger": { + "platform": DOMAIN, + TAG_ID: ["abc123", "def456"], + DEVICE_ID: ["ghi789", "jkl0123"], + }, + "action": { + "service": "test.automation", + "data": {"message": "service called"}, + }, + } + ] + }, + ) + + await hass.async_block_till_done() + + # Should not trigger + await async_scan_tag(hass, tag_id="abc123", device_id=None) + await async_scan_tag(hass, tag_id="abc123", device_id="invalid") + await hass.async_block_till_done() + + # Should trigger + await async_scan_tag(hass, tag_id="abc123", device_id="ghi789") + await hass.async_block_till_done() + await async_scan_tag(hass, tag_id="abc123", device_id="jkl0123") + await hass.async_block_till_done() + await async_scan_tag(hass, "def456", device_id="ghi789") + await hass.async_block_till_done() + await async_scan_tag(hass, "def456", device_id="jkl0123") + await hass.async_block_till_done() + + assert len(calls) == 4 + assert calls[0].data["message"] == "service called" + assert calls[1].data["message"] == "service called" + assert calls[2].data["message"] == "service called" + assert calls[3].data["message"] == "service called"