diff --git a/homeassistant/components/webhook/__init__.py b/homeassistant/components/webhook/__init__.py index b067321a1c0..e58890a1d18 100644 --- a/homeassistant/components/webhook/__init__.py +++ b/homeassistant/components/webhook/__init__.py @@ -1,7 +1,7 @@ """Webhooks for Home Assistant.""" from __future__ import annotations -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Iterable from http import HTTPStatus from ipaddress import ip_address import logging @@ -9,6 +9,7 @@ import secrets from typing import TYPE_CHECKING, Any from aiohttp import StreamReader +from aiohttp.hdrs import METH_GET, METH_HEAD, METH_POST, METH_PUT from aiohttp.web import Request, Response import voluptuous as vol @@ -25,6 +26,8 @@ _LOGGER = logging.getLogger(__name__) DOMAIN = "webhook" +DEFAULT_METHODS = (METH_POST, METH_PUT) +SUPPORTED_METHODS = (METH_GET, METH_HEAD, METH_POST, METH_PUT) URL_WEBHOOK_PATH = "/api/webhook/{webhook_id}" @@ -37,7 +40,8 @@ def async_register( webhook_id: str, handler: Callable[[HomeAssistant, str, Request], Awaitable[Response | None]], *, - local_only=False, + local_only: bool | None = False, + allowed_methods: Iterable[str] | None = None, ) -> None: """Register a webhook.""" handlers = hass.data.setdefault(DOMAIN, {}) @@ -45,11 +49,21 @@ def async_register( if webhook_id in handlers: raise ValueError("Handler is already defined!") + if allowed_methods is None: + allowed_methods = DEFAULT_METHODS + allowed_methods = frozenset(allowed_methods) + + if not allowed_methods.issubset(SUPPORTED_METHODS): + raise ValueError( + f"Unexpected method: {allowed_methods.difference(SUPPORTED_METHODS)}" + ) + handlers[webhook_id] = { "domain": domain, "name": name, "handler": handler, "local_only": local_only, + "allowed_methods": allowed_methods, } @@ -90,16 +104,18 @@ async def async_handle_webhook( """Handle a webhook.""" handlers: dict[str, dict[str, Any]] = hass.data.setdefault(DOMAIN, {}) + content_stream: StreamReader | MockStreamReader + if isinstance(request, MockRequest): + received_from = request.mock_source + content_stream = request.content + method_name = request.method + else: + received_from = request.remote + content_stream = request.content + method_name = request.method + # Always respond successfully to not give away if a hook exists or not. if (webhook := handlers.get(webhook_id)) is None: - content_stream: StreamReader | MockStreamReader - if isinstance(request, MockRequest): - received_from = request.mock_source - content_stream = request.content - else: - received_from = request.remote - content_stream = request.content - _LOGGER.info( "Received message for unregistered webhook %s from %s", webhook_id, @@ -111,7 +127,21 @@ async def async_handle_webhook( _LOGGER.debug("%s", content) return Response(status=HTTPStatus.OK) - if webhook["local_only"]: + if method_name not in webhook["allowed_methods"]: + if method_name == METH_HEAD: + # Allow websites to verify that the URL exists. + return Response(status=HTTPStatus.OK) + + _LOGGER.warning( + "Webhook %s only supports %s methods but %s was received from %s", + webhook_id, + ",".join(webhook["allowed_methods"]), + method_name, + received_from, + ) + return Response(status=HTTPStatus.METHOD_NOT_ALLOWED) + + if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest): if TYPE_CHECKING: assert isinstance(request, Request) assert request.remote is not None @@ -123,7 +153,17 @@ async def async_handle_webhook( if not network.is_local(remote): _LOGGER.warning("Received remote request for local webhook %s", webhook_id) - return Response(status=HTTPStatus.OK) + if webhook["local_only"]: + return Response(status=HTTPStatus.OK) + if not webhook.get("warned_about_deprecation"): + webhook["warned_about_deprecation"] = True + _LOGGER.warning( + "Deprecation warning: " + "Webhook '%s' does not provide a value for local_only. " + "This webhook will be blocked after the 2023.7.0 release. " + "Use `local_only: false` to keep this webhook operating as-is", + webhook_id, + ) try: response = await webhook["handler"](hass, webhook_id, request) @@ -157,6 +197,7 @@ class WebhookView(HomeAssistantView): hass = request.app["hass"] return await async_handle_webhook(hass, webhook_id, request) + get = _handle head = _handle post = _handle put = _handle @@ -182,6 +223,7 @@ def websocket_list( "domain": info["domain"], "name": info["name"], "local_only": info["local_only"], + "allowed_methods": sorted(info["allowed_methods"]), } for webhook_id, info in handlers.items() ] @@ -193,7 +235,7 @@ def websocket_list( { vol.Required("type"): "webhook/handle", vol.Required("webhook_id"): str, - vol.Required("method"): vol.In(["GET", "POST", "PUT"]), + vol.Required("method"): vol.In(SUPPORTED_METHODS), vol.Optional("body", default=""): str, vol.Optional("headers", default={}): {str: str}, vol.Optional("query", default=""): str, diff --git a/homeassistant/components/webhook/strings.json b/homeassistant/components/webhook/strings.json new file mode 100644 index 00000000000..53b932727d0 --- /dev/null +++ b/homeassistant/components/webhook/strings.json @@ -0,0 +1,8 @@ +{ + "issues": { + "trigger_missing_local_only": { + "title": "Update webhook trigger: {webhook_id}", + "description": "A choice needs to be made about whether the {webhook_id} webhook automation trigger is accessible from the internet. [Edit the automation]({edit}) \"{automation_name}\", (`{entity_id}`) and click the gear icon beside the Webhook ID to choose a value for 'Only accessible from the local network'" + } + } +} diff --git a/homeassistant/components/webhook/translations/en.json b/homeassistant/components/webhook/translations/en.json new file mode 100644 index 00000000000..619ef88203f --- /dev/null +++ b/homeassistant/components/webhook/translations/en.json @@ -0,0 +1,8 @@ +{ + "issues": { + "trigger_missing_local_only": { + "description": "A choice needs to be made about whether the {webhook_id} webhook automation trigger is accessible from the internet. Edit the {automation_name} automation, and click the gear icon beside the Webhook ID to choose a value for 'Only accessible from the local network'", + "title": "Update webhook trigger: {webhook_id}" + } + } +} \ No newline at end of file diff --git a/homeassistant/components/webhook/trigger.py b/homeassistant/components/webhook/trigger.py index cb1a6cb4eb6..48e484db4a1 100644 --- a/homeassistant/components/webhook/trigger.py +++ b/homeassistant/components/webhook/trigger.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass +import logging from aiohttp import hdrs import voluptuous as vol @@ -9,17 +10,39 @@ import voluptuous as vol from homeassistant.const import CONF_PLATFORM, CONF_WEBHOOK_ID from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.issue_registry import ( + IssueSeverity, + async_create_issue, + async_delete_issue, +) from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo from homeassistant.helpers.typing import ConfigType -from . import DOMAIN, async_register, async_unregister +from . import ( + DEFAULT_METHODS, + DOMAIN, + SUPPORTED_METHODS, + async_register, + async_unregister, +) + +_LOGGER = logging.getLogger(__name__) DEPENDENCIES = ("webhook",) +CONF_ALLOWED_METHODS = "allowed_methods" +CONF_LOCAL_ONLY = "local_only" + TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend( { vol.Required(CONF_PLATFORM): "webhook", vol.Required(CONF_WEBHOOK_ID): cv.string, + vol.Optional(CONF_ALLOWED_METHODS): vol.All( + cv.ensure_list, + [vol.All(vol.Upper, vol.In(SUPPORTED_METHODS))], + vol.Unique(), + ), + vol.Optional(CONF_LOCAL_ONLY): bool, } ) @@ -62,6 +85,32 @@ async def async_attach_trigger( ) -> CALLBACK_TYPE: """Trigger based on incoming webhooks.""" webhook_id: str = config[CONF_WEBHOOK_ID] + local_only = config.get(CONF_LOCAL_ONLY) + issue_id: str | None = None + if local_only is None: + issue_id = f"trigger_missing_local_only_{webhook_id}" + variables = trigger_info["variables"] or {} + automation_info = variables.get("this", {}) + automation_id = automation_info.get("attributes", {}).get("id") + automation_entity_id = automation_info.get("entity_id") + automation_name = trigger_info.get("name") or automation_entity_id + async_create_issue( + hass, + DOMAIN, + issue_id, + breaks_in_ha_version="2023.7.0", + is_fixable=False, + severity=IssueSeverity.WARNING, + learn_more_url="https://www.home-assistant.io/docs/automation/trigger/#webhook-trigger", + translation_key="trigger_missing_local_only", + translation_placeholders={ + "webhook_id": webhook_id, + "automation_name": automation_name, + "entity_id": automation_entity_id, + "edit": f"/config/automation/edit/{automation_id}", + }, + ) + allowed_methods = config.get(CONF_ALLOWED_METHODS, DEFAULT_METHODS) job = HassJob(action) triggers: dict[str, list[TriggerInstance]] = hass.data.setdefault( @@ -75,6 +124,8 @@ async def async_attach_trigger( trigger_info["name"], webhook_id, _handle_webhook, + local_only=local_only, + allowed_methods=allowed_methods, ) triggers[webhook_id] = [] @@ -84,6 +135,8 @@ async def async_attach_trigger( @callback def unregister(): """Unregister webhook.""" + if issue_id: + async_delete_issue(hass, DOMAIN, issue_id) triggers[webhook_id].remove(trigger_instance) if not triggers[webhook_id]: async_unregister(hass, webhook_id) diff --git a/tests/components/netatmo/common.py b/tests/components/netatmo/common.py index 375dce4e723..a3f7dfcb9d2 100644 --- a/tests/components/netatmo/common.py +++ b/tests/components/netatmo/common.py @@ -81,6 +81,7 @@ async def fake_post_request_no_data(*args, **kwargs): async def simulate_webhook(hass, webhook_id, response): """Simulate a webhook event.""" request = MockRequest( + method="POST", content=bytes(json.dumps({**COMMON_RESPONSE, **response}), "utf-8"), mock_source="test", ) diff --git a/tests/components/webhook/test_init.py b/tests/components/webhook/test_init.py index e9d3d6e38b8..580366c16ea 100644 --- a/tests/components/webhook/test_init.py +++ b/tests/components/webhook/test_init.py @@ -159,7 +159,9 @@ async def test_webhook_head(hass: HomeAssistant, mock_client) -> None: """Handle webhook.""" hooks.append(args) - webhook.async_register(hass, "test", "Test hook", webhook_id, handle) + webhook.async_register( + hass, "test", "Test hook", webhook_id, handle, allowed_methods=["HEAD"] + ) resp = await mock_client.head(f"/api/webhook/{webhook_id}") assert resp.status == HTTPStatus.OK @@ -168,6 +170,58 @@ async def test_webhook_head(hass: HomeAssistant, mock_client) -> None: assert hooks[0][1] == webhook_id assert hooks[0][2].method == "HEAD" + # Test that status is HTTPStatus.OK even when HEAD is not allowed. + webhook.async_unregister(hass, webhook_id) + webhook.async_register( + hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PUT"] + ) + resp = await mock_client.head(f"/api/webhook/{webhook_id}") + assert resp.status == HTTPStatus.OK + assert len(hooks) == 1 # Should not have been called + + +async def test_webhook_get(hass, mock_client): + """Test sending a get request to a webhook.""" + hooks = [] + webhook_id = webhook.async_generate_id() + + async def handle(*args): + """Handle webhook.""" + hooks.append(args) + + webhook.async_register( + hass, "test", "Test hook", webhook_id, handle, allowed_methods=["GET"] + ) + + resp = await mock_client.get(f"/api/webhook/{webhook_id}") + assert resp.status == HTTPStatus.OK + assert len(hooks) == 1 + assert hooks[0][0] is hass + assert hooks[0][1] == webhook_id + assert hooks[0][2].method == "GET" + + # Test that status is HTTPStatus.METHOD_NOT_ALLOWED even when GET is not allowed. + webhook.async_unregister(hass, webhook_id) + webhook.async_register( + hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PUT"] + ) + resp = await mock_client.get(f"/api/webhook/{webhook_id}") + assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED + assert len(hooks) == 1 # Should not have been called + + +async def test_webhook_not_allowed_method(hass): + """Test that an exception is raised if an unsupported method is used.""" + webhook_id = webhook.async_generate_id() + + async def handle(*args): + pass + + with pytest.raises(ValueError): + webhook.async_register( + hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PATCH"] + ) + async def test_webhook_local_only(hass: HomeAssistant, mock_client) -> None: """Test posting a webhook with local only.""" @@ -211,7 +265,15 @@ async def test_listing_webhook( client = await hass_ws_client(hass, hass_access_token) webhook.async_register(hass, "test", "Test hook", "my-id", None) - webhook.async_register(hass, "test", "Test hook", "my-2", None, local_only=True) + webhook.async_register( + hass, + "test", + "Test hook", + "my-2", + None, + local_only=True, + allowed_methods=["GET"], + ) await client.send_json({"id": 5, "type": "webhook/list"}) @@ -224,12 +286,14 @@ async def test_listing_webhook( "domain": "test", "name": "Test hook", "local_only": False, + "allowed_methods": ["POST", "PUT"], }, { "webhook_id": "my-2", "domain": "test", "name": "Test hook", "local_only": True, + "allowed_methods": ["GET"], }, ] diff --git a/tests/components/webhook/test_trigger.py b/tests/components/webhook/test_trigger.py index 1912a962cd0..c2788deca30 100644 --- a/tests/components/webhook/test_trigger.py +++ b/tests/components/webhook/test_trigger.py @@ -1,4 +1,5 @@ """The tests for the webhook automation trigger.""" +from ipaddress import ip_address from unittest.mock import patch import pytest @@ -77,7 +78,11 @@ async def test_webhook_post( "automation", { "automation": { - "trigger": {"platform": "webhook", "webhook_id": "post_webhook"}, + "trigger": { + "platform": "webhook", + "webhook_id": "post_webhook", + "local_only": True, + }, "action": { "event": "test_success", "event_data_template": {"hello": "yo {{ trigger.data.hello }}"}, @@ -95,6 +100,64 @@ async def test_webhook_post( assert len(events) == 1 assert events[0].data["hello"] == "yo world" + # Request from remote IP + with patch( + "homeassistant.components.webhook.ip_address", + return_value=ip_address("123.123.123.123"), + ): + await client.post("/api/webhook/post_webhook", data={"hello": "world"}) + # No hook received + await hass.async_block_till_done() + assert len(events) == 1 + + +async def test_webhook_allowed_methods_internet(hass, hass_client_no_auth): + """Test the webhook obeys allowed_methods and local_only options.""" + events = [] + + @callback + def store_event(event): + """Help store events.""" + events.append(event) + + hass.bus.async_listen("test_success", store_event) + + assert await async_setup_component( + hass, + "automation", + { + "automation": { + "trigger": { + "platform": "webhook", + "webhook_id": "post_webhook", + "allowed_methods": "PUT", + # Enable after 2023.4.0: "local_only": False, + }, + "action": { + "event": "test_success", + "event_data_template": {"hello": "yo {{ trigger.data.hello }}"}, + }, + } + }, + ) + await hass.async_block_till_done() + + client = await hass_client_no_auth() + + await client.post("/api/webhook/post_webhook", data={"hello": "world"}) + await hass.async_block_till_done() + + assert len(events) == 0 + + # Request from remote IP + with patch( + "homeassistant.components.webhook.ip_address", + return_value=ip_address("123.123.123.123"), + ): + await client.put("/api/webhook/post_webhook", data={"hello": "world"}) + await hass.async_block_till_done() + assert len(events) == 1 + async def test_webhook_query( hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator