Add webhook trigger allowed_methods/local_only options (#66494)

Co-authored-by: Franck Nijhof <frenck@frenck.nl>
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Eric Severance 2023-04-14 03:49:12 -07:00 committed by GitHub
parent b23cedeae9
commit 94f35ea968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 256 additions and 17 deletions

View File

@ -1,7 +1,7 @@
"""Webhooks for Home Assistant."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Iterable
from http import HTTPStatus
from ipaddress import ip_address
import logging
@ -9,6 +9,7 @@ import secrets
from typing import TYPE_CHECKING, Any
from aiohttp import StreamReader
from aiohttp.hdrs import METH_GET, METH_HEAD, METH_POST, METH_PUT
from aiohttp.web import Request, Response
import voluptuous as vol
@ -25,6 +26,8 @@ _LOGGER = logging.getLogger(__name__)
DOMAIN = "webhook"
DEFAULT_METHODS = (METH_POST, METH_PUT)
SUPPORTED_METHODS = (METH_GET, METH_HEAD, METH_POST, METH_PUT)
URL_WEBHOOK_PATH = "/api/webhook/{webhook_id}"
@ -37,7 +40,8 @@ def async_register(
webhook_id: str,
handler: Callable[[HomeAssistant, str, Request], Awaitable[Response | None]],
*,
local_only=False,
local_only: bool | None = False,
allowed_methods: Iterable[str] | None = None,
) -> None:
"""Register a webhook."""
handlers = hass.data.setdefault(DOMAIN, {})
@ -45,11 +49,21 @@ def async_register(
if webhook_id in handlers:
raise ValueError("Handler is already defined!")
if allowed_methods is None:
allowed_methods = DEFAULT_METHODS
allowed_methods = frozenset(allowed_methods)
if not allowed_methods.issubset(SUPPORTED_METHODS):
raise ValueError(
f"Unexpected method: {allowed_methods.difference(SUPPORTED_METHODS)}"
)
handlers[webhook_id] = {
"domain": domain,
"name": name,
"handler": handler,
"local_only": local_only,
"allowed_methods": allowed_methods,
}
@ -90,16 +104,18 @@ async def async_handle_webhook(
"""Handle a webhook."""
handlers: dict[str, dict[str, Any]] = hass.data.setdefault(DOMAIN, {})
content_stream: StreamReader | MockStreamReader
if isinstance(request, MockRequest):
received_from = request.mock_source
content_stream = request.content
method_name = request.method
else:
received_from = request.remote
content_stream = request.content
method_name = request.method
# Always respond successfully to not give away if a hook exists or not.
if (webhook := handlers.get(webhook_id)) is None:
content_stream: StreamReader | MockStreamReader
if isinstance(request, MockRequest):
received_from = request.mock_source
content_stream = request.content
else:
received_from = request.remote
content_stream = request.content
_LOGGER.info(
"Received message for unregistered webhook %s from %s",
webhook_id,
@ -111,7 +127,21 @@ async def async_handle_webhook(
_LOGGER.debug("%s", content)
return Response(status=HTTPStatus.OK)
if webhook["local_only"]:
if method_name not in webhook["allowed_methods"]:
if method_name == METH_HEAD:
# Allow websites to verify that the URL exists.
return Response(status=HTTPStatus.OK)
_LOGGER.warning(
"Webhook %s only supports %s methods but %s was received from %s",
webhook_id,
",".join(webhook["allowed_methods"]),
method_name,
received_from,
)
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
if TYPE_CHECKING:
assert isinstance(request, Request)
assert request.remote is not None
@ -123,7 +153,17 @@ async def async_handle_webhook(
if not network.is_local(remote):
_LOGGER.warning("Received remote request for local webhook %s", webhook_id)
return Response(status=HTTPStatus.OK)
if webhook["local_only"]:
return Response(status=HTTPStatus.OK)
if not webhook.get("warned_about_deprecation"):
webhook["warned_about_deprecation"] = True
_LOGGER.warning(
"Deprecation warning: "
"Webhook '%s' does not provide a value for local_only. "
"This webhook will be blocked after the 2023.7.0 release. "
"Use `local_only: false` to keep this webhook operating as-is",
webhook_id,
)
try:
response = await webhook["handler"](hass, webhook_id, request)
@ -157,6 +197,7 @@ class WebhookView(HomeAssistantView):
hass = request.app["hass"]
return await async_handle_webhook(hass, webhook_id, request)
get = _handle
head = _handle
post = _handle
put = _handle
@ -182,6 +223,7 @@ def websocket_list(
"domain": info["domain"],
"name": info["name"],
"local_only": info["local_only"],
"allowed_methods": sorted(info["allowed_methods"]),
}
for webhook_id, info in handlers.items()
]
@ -193,7 +235,7 @@ def websocket_list(
{
vol.Required("type"): "webhook/handle",
vol.Required("webhook_id"): str,
vol.Required("method"): vol.In(["GET", "POST", "PUT"]),
vol.Required("method"): vol.In(SUPPORTED_METHODS),
vol.Optional("body", default=""): str,
vol.Optional("headers", default={}): {str: str},
vol.Optional("query", default=""): str,

View File

@ -0,0 +1,8 @@
{
"issues": {
"trigger_missing_local_only": {
"title": "Update webhook trigger: {webhook_id}",
"description": "A choice needs to be made about whether the {webhook_id} webhook automation trigger is accessible from the internet. [Edit the automation]({edit}) \"{automation_name}\", (`{entity_id}`) and click the gear icon beside the Webhook ID to choose a value for 'Only accessible from the local network'"
}
}
}

View File

@ -0,0 +1,8 @@
{
"issues": {
"trigger_missing_local_only": {
"description": "A choice needs to be made about whether the {webhook_id} webhook automation trigger is accessible from the internet. Edit the {automation_name} automation, and click the gear icon beside the Webhook ID to choose a value for 'Only accessible from the local network'",
"title": "Update webhook trigger: {webhook_id}"
}
}
}

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from dataclasses import dataclass
import logging
from aiohttp import hdrs
import voluptuous as vol
@ -9,17 +10,39 @@ import voluptuous as vol
from homeassistant.const import CONF_PLATFORM, CONF_WEBHOOK_ID
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.issue_registry import (
IssueSeverity,
async_create_issue,
async_delete_issue,
)
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
from homeassistant.helpers.typing import ConfigType
from . import DOMAIN, async_register, async_unregister
from . import (
DEFAULT_METHODS,
DOMAIN,
SUPPORTED_METHODS,
async_register,
async_unregister,
)
_LOGGER = logging.getLogger(__name__)
DEPENDENCIES = ("webhook",)
CONF_ALLOWED_METHODS = "allowed_methods"
CONF_LOCAL_ONLY = "local_only"
TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
{
vol.Required(CONF_PLATFORM): "webhook",
vol.Required(CONF_WEBHOOK_ID): cv.string,
vol.Optional(CONF_ALLOWED_METHODS): vol.All(
cv.ensure_list,
[vol.All(vol.Upper, vol.In(SUPPORTED_METHODS))],
vol.Unique(),
),
vol.Optional(CONF_LOCAL_ONLY): bool,
}
)
@ -62,6 +85,32 @@ async def async_attach_trigger(
) -> CALLBACK_TYPE:
"""Trigger based on incoming webhooks."""
webhook_id: str = config[CONF_WEBHOOK_ID]
local_only = config.get(CONF_LOCAL_ONLY)
issue_id: str | None = None
if local_only is None:
issue_id = f"trigger_missing_local_only_{webhook_id}"
variables = trigger_info["variables"] or {}
automation_info = variables.get("this", {})
automation_id = automation_info.get("attributes", {}).get("id")
automation_entity_id = automation_info.get("entity_id")
automation_name = trigger_info.get("name") or automation_entity_id
async_create_issue(
hass,
DOMAIN,
issue_id,
breaks_in_ha_version="2023.7.0",
is_fixable=False,
severity=IssueSeverity.WARNING,
learn_more_url="https://www.home-assistant.io/docs/automation/trigger/#webhook-trigger",
translation_key="trigger_missing_local_only",
translation_placeholders={
"webhook_id": webhook_id,
"automation_name": automation_name,
"entity_id": automation_entity_id,
"edit": f"/config/automation/edit/{automation_id}",
},
)
allowed_methods = config.get(CONF_ALLOWED_METHODS, DEFAULT_METHODS)
job = HassJob(action)
triggers: dict[str, list[TriggerInstance]] = hass.data.setdefault(
@ -75,6 +124,8 @@ async def async_attach_trigger(
trigger_info["name"],
webhook_id,
_handle_webhook,
local_only=local_only,
allowed_methods=allowed_methods,
)
triggers[webhook_id] = []
@ -84,6 +135,8 @@ async def async_attach_trigger(
@callback
def unregister():
"""Unregister webhook."""
if issue_id:
async_delete_issue(hass, DOMAIN, issue_id)
triggers[webhook_id].remove(trigger_instance)
if not triggers[webhook_id]:
async_unregister(hass, webhook_id)

View File

@ -81,6 +81,7 @@ async def fake_post_request_no_data(*args, **kwargs):
async def simulate_webhook(hass, webhook_id, response):
"""Simulate a webhook event."""
request = MockRequest(
method="POST",
content=bytes(json.dumps({**COMMON_RESPONSE, **response}), "utf-8"),
mock_source="test",
)

View File

@ -159,7 +159,9 @@ async def test_webhook_head(hass: HomeAssistant, mock_client) -> None:
"""Handle webhook."""
hooks.append(args)
webhook.async_register(hass, "test", "Test hook", webhook_id, handle)
webhook.async_register(
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["HEAD"]
)
resp = await mock_client.head(f"/api/webhook/{webhook_id}")
assert resp.status == HTTPStatus.OK
@ -168,6 +170,58 @@ async def test_webhook_head(hass: HomeAssistant, mock_client) -> None:
assert hooks[0][1] == webhook_id
assert hooks[0][2].method == "HEAD"
# Test that status is HTTPStatus.OK even when HEAD is not allowed.
webhook.async_unregister(hass, webhook_id)
webhook.async_register(
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PUT"]
)
resp = await mock_client.head(f"/api/webhook/{webhook_id}")
assert resp.status == HTTPStatus.OK
assert len(hooks) == 1 # Should not have been called
async def test_webhook_get(hass, mock_client):
"""Test sending a get request to a webhook."""
hooks = []
webhook_id = webhook.async_generate_id()
async def handle(*args):
"""Handle webhook."""
hooks.append(args)
webhook.async_register(
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["GET"]
)
resp = await mock_client.get(f"/api/webhook/{webhook_id}")
assert resp.status == HTTPStatus.OK
assert len(hooks) == 1
assert hooks[0][0] is hass
assert hooks[0][1] == webhook_id
assert hooks[0][2].method == "GET"
# Test that status is HTTPStatus.METHOD_NOT_ALLOWED even when GET is not allowed.
webhook.async_unregister(hass, webhook_id)
webhook.async_register(
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PUT"]
)
resp = await mock_client.get(f"/api/webhook/{webhook_id}")
assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED
assert len(hooks) == 1 # Should not have been called
async def test_webhook_not_allowed_method(hass):
"""Test that an exception is raised if an unsupported method is used."""
webhook_id = webhook.async_generate_id()
async def handle(*args):
pass
with pytest.raises(ValueError):
webhook.async_register(
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PATCH"]
)
async def test_webhook_local_only(hass: HomeAssistant, mock_client) -> None:
"""Test posting a webhook with local only."""
@ -211,7 +265,15 @@ async def test_listing_webhook(
client = await hass_ws_client(hass, hass_access_token)
webhook.async_register(hass, "test", "Test hook", "my-id", None)
webhook.async_register(hass, "test", "Test hook", "my-2", None, local_only=True)
webhook.async_register(
hass,
"test",
"Test hook",
"my-2",
None,
local_only=True,
allowed_methods=["GET"],
)
await client.send_json({"id": 5, "type": "webhook/list"})
@ -224,12 +286,14 @@ async def test_listing_webhook(
"domain": "test",
"name": "Test hook",
"local_only": False,
"allowed_methods": ["POST", "PUT"],
},
{
"webhook_id": "my-2",
"domain": "test",
"name": "Test hook",
"local_only": True,
"allowed_methods": ["GET"],
},
]

View File

@ -1,4 +1,5 @@
"""The tests for the webhook automation trigger."""
from ipaddress import ip_address
from unittest.mock import patch
import pytest
@ -77,7 +78,11 @@ async def test_webhook_post(
"automation",
{
"automation": {
"trigger": {"platform": "webhook", "webhook_id": "post_webhook"},
"trigger": {
"platform": "webhook",
"webhook_id": "post_webhook",
"local_only": True,
},
"action": {
"event": "test_success",
"event_data_template": {"hello": "yo {{ trigger.data.hello }}"},
@ -95,6 +100,64 @@ async def test_webhook_post(
assert len(events) == 1
assert events[0].data["hello"] == "yo world"
# Request from remote IP
with patch(
"homeassistant.components.webhook.ip_address",
return_value=ip_address("123.123.123.123"),
):
await client.post("/api/webhook/post_webhook", data={"hello": "world"})
# No hook received
await hass.async_block_till_done()
assert len(events) == 1
async def test_webhook_allowed_methods_internet(hass, hass_client_no_auth):
"""Test the webhook obeys allowed_methods and local_only options."""
events = []
@callback
def store_event(event):
"""Help store events."""
events.append(event)
hass.bus.async_listen("test_success", store_event)
assert await async_setup_component(
hass,
"automation",
{
"automation": {
"trigger": {
"platform": "webhook",
"webhook_id": "post_webhook",
"allowed_methods": "PUT",
# Enable after 2023.4.0: "local_only": False,
},
"action": {
"event": "test_success",
"event_data_template": {"hello": "yo {{ trigger.data.hello }}"},
},
}
},
)
await hass.async_block_till_done()
client = await hass_client_no_auth()
await client.post("/api/webhook/post_webhook", data={"hello": "world"})
await hass.async_block_till_done()
assert len(events) == 0
# Request from remote IP
with patch(
"homeassistant.components.webhook.ip_address",
return_value=ip_address("123.123.123.123"),
):
await client.put("/api/webhook/post_webhook", data={"hello": "world"})
await hass.async_block_till_done()
assert len(events) == 1
async def test_webhook_query(
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator