mirror of
https://github.com/home-assistant/core.git
synced 2025-04-22 16:27:56 +00:00
Add webhook trigger allowed_methods/local_only options (#66494)
Co-authored-by: Franck Nijhof <frenck@frenck.nl> Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
b23cedeae9
commit
94f35ea968
@ -1,7 +1,7 @@
|
||||
"""Webhooks for Home Assistant."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_address
|
||||
import logging
|
||||
@ -9,6 +9,7 @@ import secrets
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiohttp import StreamReader
|
||||
from aiohttp.hdrs import METH_GET, METH_HEAD, METH_POST, METH_PUT
|
||||
from aiohttp.web import Request, Response
|
||||
import voluptuous as vol
|
||||
|
||||
@ -25,6 +26,8 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN = "webhook"
|
||||
|
||||
DEFAULT_METHODS = (METH_POST, METH_PUT)
|
||||
SUPPORTED_METHODS = (METH_GET, METH_HEAD, METH_POST, METH_PUT)
|
||||
URL_WEBHOOK_PATH = "/api/webhook/{webhook_id}"
|
||||
|
||||
|
||||
@ -37,7 +40,8 @@ def async_register(
|
||||
webhook_id: str,
|
||||
handler: Callable[[HomeAssistant, str, Request], Awaitable[Response | None]],
|
||||
*,
|
||||
local_only=False,
|
||||
local_only: bool | None = False,
|
||||
allowed_methods: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""Register a webhook."""
|
||||
handlers = hass.data.setdefault(DOMAIN, {})
|
||||
@ -45,11 +49,21 @@ def async_register(
|
||||
if webhook_id in handlers:
|
||||
raise ValueError("Handler is already defined!")
|
||||
|
||||
if allowed_methods is None:
|
||||
allowed_methods = DEFAULT_METHODS
|
||||
allowed_methods = frozenset(allowed_methods)
|
||||
|
||||
if not allowed_methods.issubset(SUPPORTED_METHODS):
|
||||
raise ValueError(
|
||||
f"Unexpected method: {allowed_methods.difference(SUPPORTED_METHODS)}"
|
||||
)
|
||||
|
||||
handlers[webhook_id] = {
|
||||
"domain": domain,
|
||||
"name": name,
|
||||
"handler": handler,
|
||||
"local_only": local_only,
|
||||
"allowed_methods": allowed_methods,
|
||||
}
|
||||
|
||||
|
||||
@ -90,16 +104,18 @@ async def async_handle_webhook(
|
||||
"""Handle a webhook."""
|
||||
handlers: dict[str, dict[str, Any]] = hass.data.setdefault(DOMAIN, {})
|
||||
|
||||
content_stream: StreamReader | MockStreamReader
|
||||
if isinstance(request, MockRequest):
|
||||
received_from = request.mock_source
|
||||
content_stream = request.content
|
||||
method_name = request.method
|
||||
else:
|
||||
received_from = request.remote
|
||||
content_stream = request.content
|
||||
method_name = request.method
|
||||
|
||||
# Always respond successfully to not give away if a hook exists or not.
|
||||
if (webhook := handlers.get(webhook_id)) is None:
|
||||
content_stream: StreamReader | MockStreamReader
|
||||
if isinstance(request, MockRequest):
|
||||
received_from = request.mock_source
|
||||
content_stream = request.content
|
||||
else:
|
||||
received_from = request.remote
|
||||
content_stream = request.content
|
||||
|
||||
_LOGGER.info(
|
||||
"Received message for unregistered webhook %s from %s",
|
||||
webhook_id,
|
||||
@ -111,7 +127,21 @@ async def async_handle_webhook(
|
||||
_LOGGER.debug("%s", content)
|
||||
return Response(status=HTTPStatus.OK)
|
||||
|
||||
if webhook["local_only"]:
|
||||
if method_name not in webhook["allowed_methods"]:
|
||||
if method_name == METH_HEAD:
|
||||
# Allow websites to verify that the URL exists.
|
||||
return Response(status=HTTPStatus.OK)
|
||||
|
||||
_LOGGER.warning(
|
||||
"Webhook %s only supports %s methods but %s was received from %s",
|
||||
webhook_id,
|
||||
",".join(webhook["allowed_methods"]),
|
||||
method_name,
|
||||
received_from,
|
||||
)
|
||||
return Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||
|
||||
if webhook["local_only"] in (True, None) and not isinstance(request, MockRequest):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(request, Request)
|
||||
assert request.remote is not None
|
||||
@ -123,7 +153,17 @@ async def async_handle_webhook(
|
||||
|
||||
if not network.is_local(remote):
|
||||
_LOGGER.warning("Received remote request for local webhook %s", webhook_id)
|
||||
return Response(status=HTTPStatus.OK)
|
||||
if webhook["local_only"]:
|
||||
return Response(status=HTTPStatus.OK)
|
||||
if not webhook.get("warned_about_deprecation"):
|
||||
webhook["warned_about_deprecation"] = True
|
||||
_LOGGER.warning(
|
||||
"Deprecation warning: "
|
||||
"Webhook '%s' does not provide a value for local_only. "
|
||||
"This webhook will be blocked after the 2023.7.0 release. "
|
||||
"Use `local_only: false` to keep this webhook operating as-is",
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
try:
|
||||
response = await webhook["handler"](hass, webhook_id, request)
|
||||
@ -157,6 +197,7 @@ class WebhookView(HomeAssistantView):
|
||||
hass = request.app["hass"]
|
||||
return await async_handle_webhook(hass, webhook_id, request)
|
||||
|
||||
get = _handle
|
||||
head = _handle
|
||||
post = _handle
|
||||
put = _handle
|
||||
@ -182,6 +223,7 @@ def websocket_list(
|
||||
"domain": info["domain"],
|
||||
"name": info["name"],
|
||||
"local_only": info["local_only"],
|
||||
"allowed_methods": sorted(info["allowed_methods"]),
|
||||
}
|
||||
for webhook_id, info in handlers.items()
|
||||
]
|
||||
@ -193,7 +235,7 @@ def websocket_list(
|
||||
{
|
||||
vol.Required("type"): "webhook/handle",
|
||||
vol.Required("webhook_id"): str,
|
||||
vol.Required("method"): vol.In(["GET", "POST", "PUT"]),
|
||||
vol.Required("method"): vol.In(SUPPORTED_METHODS),
|
||||
vol.Optional("body", default=""): str,
|
||||
vol.Optional("headers", default={}): {str: str},
|
||||
vol.Optional("query", default=""): str,
|
||||
|
8
homeassistant/components/webhook/strings.json
Normal file
8
homeassistant/components/webhook/strings.json
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"issues": {
|
||||
"trigger_missing_local_only": {
|
||||
"title": "Update webhook trigger: {webhook_id}",
|
||||
"description": "A choice needs to be made about whether the {webhook_id} webhook automation trigger is accessible from the internet. [Edit the automation]({edit}) \"{automation_name}\", (`{entity_id}`) and click the gear icon beside the Webhook ID to choose a value for 'Only accessible from the local network'"
|
||||
}
|
||||
}
|
||||
}
|
8
homeassistant/components/webhook/translations/en.json
Normal file
8
homeassistant/components/webhook/translations/en.json
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"issues": {
|
||||
"trigger_missing_local_only": {
|
||||
"description": "A choice needs to be made about whether the {webhook_id} webhook automation trigger is accessible from the internet. Edit the {automation_name} automation, and click the gear icon beside the Webhook ID to choose a value for 'Only accessible from the local network'",
|
||||
"title": "Update webhook trigger: {webhook_id}"
|
||||
}
|
||||
}
|
||||
}
|
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
from aiohttp import hdrs
|
||||
import voluptuous as vol
|
||||
@ -9,17 +10,39 @@ import voluptuous as vol
|
||||
from homeassistant.const import CONF_PLATFORM, CONF_WEBHOOK_ID
|
||||
from homeassistant.core import CALLBACK_TYPE, HassJob, HomeAssistant, callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.issue_registry import (
|
||||
IssueSeverity,
|
||||
async_create_issue,
|
||||
async_delete_issue,
|
||||
)
|
||||
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import DOMAIN, async_register, async_unregister
|
||||
from . import (
|
||||
DEFAULT_METHODS,
|
||||
DOMAIN,
|
||||
SUPPORTED_METHODS,
|
||||
async_register,
|
||||
async_unregister,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEPENDENCIES = ("webhook",)
|
||||
|
||||
CONF_ALLOWED_METHODS = "allowed_methods"
|
||||
CONF_LOCAL_ONLY = "local_only"
|
||||
|
||||
TRIGGER_SCHEMA = cv.TRIGGER_BASE_SCHEMA.extend(
|
||||
{
|
||||
vol.Required(CONF_PLATFORM): "webhook",
|
||||
vol.Required(CONF_WEBHOOK_ID): cv.string,
|
||||
vol.Optional(CONF_ALLOWED_METHODS): vol.All(
|
||||
cv.ensure_list,
|
||||
[vol.All(vol.Upper, vol.In(SUPPORTED_METHODS))],
|
||||
vol.Unique(),
|
||||
),
|
||||
vol.Optional(CONF_LOCAL_ONLY): bool,
|
||||
}
|
||||
)
|
||||
|
||||
@ -62,6 +85,32 @@ async def async_attach_trigger(
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Trigger based on incoming webhooks."""
|
||||
webhook_id: str = config[CONF_WEBHOOK_ID]
|
||||
local_only = config.get(CONF_LOCAL_ONLY)
|
||||
issue_id: str | None = None
|
||||
if local_only is None:
|
||||
issue_id = f"trigger_missing_local_only_{webhook_id}"
|
||||
variables = trigger_info["variables"] or {}
|
||||
automation_info = variables.get("this", {})
|
||||
automation_id = automation_info.get("attributes", {}).get("id")
|
||||
automation_entity_id = automation_info.get("entity_id")
|
||||
automation_name = trigger_info.get("name") or automation_entity_id
|
||||
async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
issue_id,
|
||||
breaks_in_ha_version="2023.7.0",
|
||||
is_fixable=False,
|
||||
severity=IssueSeverity.WARNING,
|
||||
learn_more_url="https://www.home-assistant.io/docs/automation/trigger/#webhook-trigger",
|
||||
translation_key="trigger_missing_local_only",
|
||||
translation_placeholders={
|
||||
"webhook_id": webhook_id,
|
||||
"automation_name": automation_name,
|
||||
"entity_id": automation_entity_id,
|
||||
"edit": f"/config/automation/edit/{automation_id}",
|
||||
},
|
||||
)
|
||||
allowed_methods = config.get(CONF_ALLOWED_METHODS, DEFAULT_METHODS)
|
||||
job = HassJob(action)
|
||||
|
||||
triggers: dict[str, list[TriggerInstance]] = hass.data.setdefault(
|
||||
@ -75,6 +124,8 @@ async def async_attach_trigger(
|
||||
trigger_info["name"],
|
||||
webhook_id,
|
||||
_handle_webhook,
|
||||
local_only=local_only,
|
||||
allowed_methods=allowed_methods,
|
||||
)
|
||||
triggers[webhook_id] = []
|
||||
|
||||
@ -84,6 +135,8 @@ async def async_attach_trigger(
|
||||
@callback
|
||||
def unregister():
|
||||
"""Unregister webhook."""
|
||||
if issue_id:
|
||||
async_delete_issue(hass, DOMAIN, issue_id)
|
||||
triggers[webhook_id].remove(trigger_instance)
|
||||
if not triggers[webhook_id]:
|
||||
async_unregister(hass, webhook_id)
|
||||
|
@ -81,6 +81,7 @@ async def fake_post_request_no_data(*args, **kwargs):
|
||||
async def simulate_webhook(hass, webhook_id, response):
|
||||
"""Simulate a webhook event."""
|
||||
request = MockRequest(
|
||||
method="POST",
|
||||
content=bytes(json.dumps({**COMMON_RESPONSE, **response}), "utf-8"),
|
||||
mock_source="test",
|
||||
)
|
||||
|
@ -159,7 +159,9 @@ async def test_webhook_head(hass: HomeAssistant, mock_client) -> None:
|
||||
"""Handle webhook."""
|
||||
hooks.append(args)
|
||||
|
||||
webhook.async_register(hass, "test", "Test hook", webhook_id, handle)
|
||||
webhook.async_register(
|
||||
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["HEAD"]
|
||||
)
|
||||
|
||||
resp = await mock_client.head(f"/api/webhook/{webhook_id}")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
@ -168,6 +170,58 @@ async def test_webhook_head(hass: HomeAssistant, mock_client) -> None:
|
||||
assert hooks[0][1] == webhook_id
|
||||
assert hooks[0][2].method == "HEAD"
|
||||
|
||||
# Test that status is HTTPStatus.OK even when HEAD is not allowed.
|
||||
webhook.async_unregister(hass, webhook_id)
|
||||
webhook.async_register(
|
||||
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PUT"]
|
||||
)
|
||||
resp = await mock_client.head(f"/api/webhook/{webhook_id}")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert len(hooks) == 1 # Should not have been called
|
||||
|
||||
|
||||
async def test_webhook_get(hass, mock_client):
|
||||
"""Test sending a get request to a webhook."""
|
||||
hooks = []
|
||||
webhook_id = webhook.async_generate_id()
|
||||
|
||||
async def handle(*args):
|
||||
"""Handle webhook."""
|
||||
hooks.append(args)
|
||||
|
||||
webhook.async_register(
|
||||
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["GET"]
|
||||
)
|
||||
|
||||
resp = await mock_client.get(f"/api/webhook/{webhook_id}")
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0][0] is hass
|
||||
assert hooks[0][1] == webhook_id
|
||||
assert hooks[0][2].method == "GET"
|
||||
|
||||
# Test that status is HTTPStatus.METHOD_NOT_ALLOWED even when GET is not allowed.
|
||||
webhook.async_unregister(hass, webhook_id)
|
||||
webhook.async_register(
|
||||
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PUT"]
|
||||
)
|
||||
resp = await mock_client.get(f"/api/webhook/{webhook_id}")
|
||||
assert resp.status == HTTPStatus.METHOD_NOT_ALLOWED
|
||||
assert len(hooks) == 1 # Should not have been called
|
||||
|
||||
|
||||
async def test_webhook_not_allowed_method(hass):
|
||||
"""Test that an exception is raised if an unsupported method is used."""
|
||||
webhook_id = webhook.async_generate_id()
|
||||
|
||||
async def handle(*args):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
webhook.async_register(
|
||||
hass, "test", "Test hook", webhook_id, handle, allowed_methods=["PATCH"]
|
||||
)
|
||||
|
||||
|
||||
async def test_webhook_local_only(hass: HomeAssistant, mock_client) -> None:
|
||||
"""Test posting a webhook with local only."""
|
||||
@ -211,7 +265,15 @@ async def test_listing_webhook(
|
||||
client = await hass_ws_client(hass, hass_access_token)
|
||||
|
||||
webhook.async_register(hass, "test", "Test hook", "my-id", None)
|
||||
webhook.async_register(hass, "test", "Test hook", "my-2", None, local_only=True)
|
||||
webhook.async_register(
|
||||
hass,
|
||||
"test",
|
||||
"Test hook",
|
||||
"my-2",
|
||||
None,
|
||||
local_only=True,
|
||||
allowed_methods=["GET"],
|
||||
)
|
||||
|
||||
await client.send_json({"id": 5, "type": "webhook/list"})
|
||||
|
||||
@ -224,12 +286,14 @@ async def test_listing_webhook(
|
||||
"domain": "test",
|
||||
"name": "Test hook",
|
||||
"local_only": False,
|
||||
"allowed_methods": ["POST", "PUT"],
|
||||
},
|
||||
{
|
||||
"webhook_id": "my-2",
|
||||
"domain": "test",
|
||||
"name": "Test hook",
|
||||
"local_only": True,
|
||||
"allowed_methods": ["GET"],
|
||||
},
|
||||
]
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""The tests for the webhook automation trigger."""
|
||||
from ipaddress import ip_address
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -77,7 +78,11 @@ async def test_webhook_post(
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {"platform": "webhook", "webhook_id": "post_webhook"},
|
||||
"trigger": {
|
||||
"platform": "webhook",
|
||||
"webhook_id": "post_webhook",
|
||||
"local_only": True,
|
||||
},
|
||||
"action": {
|
||||
"event": "test_success",
|
||||
"event_data_template": {"hello": "yo {{ trigger.data.hello }}"},
|
||||
@ -95,6 +100,64 @@ async def test_webhook_post(
|
||||
assert len(events) == 1
|
||||
assert events[0].data["hello"] == "yo world"
|
||||
|
||||
# Request from remote IP
|
||||
with patch(
|
||||
"homeassistant.components.webhook.ip_address",
|
||||
return_value=ip_address("123.123.123.123"),
|
||||
):
|
||||
await client.post("/api/webhook/post_webhook", data={"hello": "world"})
|
||||
# No hook received
|
||||
await hass.async_block_till_done()
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
async def test_webhook_allowed_methods_internet(hass, hass_client_no_auth):
|
||||
"""Test the webhook obeys allowed_methods and local_only options."""
|
||||
events = []
|
||||
|
||||
@callback
|
||||
def store_event(event):
|
||||
"""Help store events."""
|
||||
events.append(event)
|
||||
|
||||
hass.bus.async_listen("test_success", store_event)
|
||||
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "webhook",
|
||||
"webhook_id": "post_webhook",
|
||||
"allowed_methods": "PUT",
|
||||
# Enable after 2023.4.0: "local_only": False,
|
||||
},
|
||||
"action": {
|
||||
"event": "test_success",
|
||||
"event_data_template": {"hello": "yo {{ trigger.data.hello }}"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
await client.post("/api/webhook/post_webhook", data={"hello": "world"})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(events) == 0
|
||||
|
||||
# Request from remote IP
|
||||
with patch(
|
||||
"homeassistant.components.webhook.ip_address",
|
||||
return_value=ip_address("123.123.123.123"),
|
||||
):
|
||||
await client.put("/api/webhook/post_webhook", data={"hello": "world"})
|
||||
await hass.async_block_till_done()
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
async def test_webhook_query(
|
||||
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
|
||||
|
Loading…
x
Reference in New Issue
Block a user