diff --git a/homeassistant/components/webhook/trigger.py b/homeassistant/components/webhook/trigger.py index 3f790b1ec42..498a7363a61 100644 --- a/homeassistant/components/webhook/trigger.py +++ b/homeassistant/components/webhook/trigger.py @@ -1,5 +1,7 @@ """Offer webhook triggered automation rules.""" -from functools import partial +from __future__ import annotations + +from dataclasses import dataclass from aiohttp import hdrs import voluptuous as vol @@ -13,7 +15,7 @@ from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.typing import ConfigType -from . import async_register, async_unregister +from . import DOMAIN, async_register, async_unregister # mypy: allow-untyped-defs @@ -26,20 +28,35 @@ TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( } ) +WEBHOOK_TRIGGERS = f"{DOMAIN}_triggers" -async def _handle_webhook(job, trigger_data, hass, webhook_id, request): + +@dataclass +class TriggerInstance: + """Attached trigger settings.""" + + automation_info: AutomationTriggerInfo + job: HassJob + + +async def _handle_webhook(hass, webhook_id, request): """Handle incoming webhook.""" - result = {"platform": "webhook", "webhook_id": webhook_id} + base_result = {"platform": "webhook", "webhook_id": webhook_id} if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): - result["json"] = await request.json() + base_result["json"] = await request.json() else: - result["data"] = await request.post() + base_result["data"] = await request.post() - result["query"] = request.query - result["description"] = "webhook" - result.update(**trigger_data) - hass.async_run_hass_job(job, {"trigger": result}) + base_result["query"] = request.query + base_result["description"] = "webhook" + + triggers: dict[str, list[TriggerInstance]] = hass.data.setdefault( + WEBHOOK_TRIGGERS, {} + ) + for trigger in triggers[webhook_id]: + result = {**base_result, **trigger.automation_info["trigger_data"]} + hass.async_run_hass_job(trigger.job, {"trigger": result}) async def async_attach_trigger( @@ -49,20 +66,32 @@ async def async_attach_trigger( automation_info: AutomationTriggerInfo, ) -> CALLBACK_TYPE: """Trigger based on incoming webhooks.""" - trigger_data = automation_info["trigger_data"] webhook_id: str = config[CONF_WEBHOOK_ID] job = HassJob(action) - async_register( - hass, - automation_info["domain"], - automation_info["name"], - webhook_id, - partial(_handle_webhook, job, trigger_data), + + triggers: dict[str, list[TriggerInstance]] = hass.data.setdefault( + WEBHOOK_TRIGGERS, {} ) + if webhook_id not in triggers: + async_register( + hass, + automation_info["domain"], + automation_info["name"], + webhook_id, + _handle_webhook, + ) + triggers[webhook_id] = [] + + trigger_instance = TriggerInstance(automation_info, job) + triggers[webhook_id].append(trigger_instance) + @callback def unregister(): """Unregister webhook.""" - async_unregister(hass, webhook_id) + triggers[webhook_id].remove(trigger_instance) + if not triggers[webhook_id]: + async_unregister(hass, webhook_id) + triggers.pop(webhook_id) return unregister diff --git a/tests/components/mobile_app/test_webhook.py b/tests/components/mobile_app/test_webhook.py index 0bc237b1c11..b7b95dff392 100644 --- a/tests/components/mobile_app/test_webhook.py +++ b/tests/components/mobile_app/test_webhook.py @@ -840,7 +840,7 @@ async def test_webhook_handle_scan_tag(hass, create_registrations, webhook_clien @callback def store_event(event): - """Helepr to store events.""" + """Help store events.""" events.append(event) hass.bus.async_listen("tag_scanned", store_event) diff --git a/tests/components/webhook/test_trigger.py b/tests/components/webhook/test_trigger.py index 2deac022b1e..e8d88845f5a 100644 --- a/tests/components/webhook/test_trigger.py +++ b/tests/components/webhook/test_trigger.py @@ -23,7 +23,7 @@ async def test_webhook_json(hass, hass_client_no_auth): @callback def store_event(event): - """Helepr to store events.""" + """Help store events.""" events.append(event) hass.bus.async_listen("test_success", store_event) @@ -62,7 +62,7 @@ async def test_webhook_post(hass, hass_client_no_auth): @callback def store_event(event): - """Helepr to store events.""" + """Help store events.""" events.append(event) hass.bus.async_listen("test_success", store_event) @@ -97,7 +97,7 @@ async def test_webhook_query(hass, hass_client_no_auth): @callback def store_event(event): - """Helepr to store events.""" + """Help store events.""" events.append(event) hass.bus.async_listen("test_success", store_event) @@ -126,13 +126,68 @@ async def test_webhook_query(hass, hass_client_no_auth): assert events[0].data["hello"] == "yo world" +async def test_webhook_multiple(hass, hass_client_no_auth): + """Test triggering multiple triggers with a POST webhook.""" + events1 = [] + events2 = [] + + @callback + def store_event1(event): + """Help store events.""" + events1.append(event) + + @callback + def store_event2(event): + """Help store events.""" + events2.append(event) + + hass.bus.async_listen("test_success1", store_event1) + hass.bus.async_listen("test_success2", store_event2) + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + { + "trigger": {"platform": "webhook", "webhook_id": "post_webhook"}, + "action": { + "event": "test_success1", + "event_data_template": {"hello": "yo {{ trigger.data.hello }}"}, + }, + }, + { + "trigger": {"platform": "webhook", "webhook_id": "post_webhook"}, + "action": { + "event": "test_success2", + "event_data_template": { + "hello": "yo2 {{ trigger.data.hello }}" + }, + }, + }, + ] + }, + ) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + await client.post("/api/webhook/post_webhook", data={"hello": "world"}) + await hass.async_block_till_done() + + assert len(events1) == 1 + assert events1[0].data["hello"] == "yo world" + assert len(events2) == 1 + assert events2[0].data["hello"] == "yo2 world" + + async def test_webhook_reload(hass, hass_client_no_auth): """Test reloading a webhook.""" events = [] @callback def store_event(event): - """Helepr to store events.""" + """Help store events.""" events.append(event) hass.bus.async_listen("test_success", store_event)