Tighten validation on calendar create event websocket (#83413)

This commit is contained in:
Allen Porter 2022-12-06 10:04:32 -08:00 committed by GitHub
parent e1923bc13b
commit 4819576b62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 152 additions and 8 deletions

View File

@ -1,10 +1,11 @@
"""Support for Google Calendar event device sensors."""
from __future__ import annotations
from collections.abc import Iterable
from collections.abc import Callable, Iterable
import dataclasses
import datetime
from http import HTTPStatus
from itertools import groupby
import logging
import re
from typing import Any, cast, final
@ -365,17 +366,67 @@ class CalendarListView(http.HomeAssistantView):
return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"])))
def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all values are of the same type."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict have values of the same type."""
uniq_values = groupby(type(obj[k]) for k in keys)
if len(list(uniq_values)) > 1:
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
return obj
return validate
def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all datetime values have a consistent timezone."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys that are datetime values have the same timezone."""
values = [obj[k] for k in keys]
if all(isinstance(value, datetime.datetime) for value in values):
uniq_values = groupby(value.tzinfo for value in values)
if len(list(uniq_values)) > 1:
raise vol.Invalid(
f"Expected all values to have the same timezone: {values}"
)
return obj
return validate
def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the specified values are sequential."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict are in order."""
values = [obj[k] for k in keys]
if values != sorted(values):
raise vol.Invalid(f"Values were not in order: {values}")
return obj
return validate
@websocket_api.websocket_command(
{
vol.Required("type"): "calendar/event/create",
vol.Required("entity_id"): cv.entity_id,
vol.Required(CONF_EVENT): {
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
CONF_EVENT: vol.Schema(
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),
}
)
@websocket_api.async_response

View File

@ -612,3 +612,96 @@ async def test_all_day_iter_order(
events = await get_events("2022-10-06T00:00:00Z", "2022-10-09T00:00:00Z")
assert [event["summary"] for event in events] == event_order
async def test_start_end_types(
ws_client: ClientFixture,
setup_integration: None,
):
"""Test a start and end with different date and date time types."""
client = await ws_client()
result = await client.cmd(
"create",
{
"entity_id": TEST_ENTITY,
"event": {
"summary": "Bastille Day Party",
"dtstart": "1997-07-15",
"dtend": "1997-07-14T17:00:00+00:00",
},
},
)
assert not result.get("success")
assert "error" in result
assert "code" in result.get("error")
assert result["error"]["code"] == "invalid_format"
async def test_end_before_start(
ws_client: ClientFixture,
setup_integration: None,
):
"""Test an event with a start/end date time."""
client = await ws_client()
result = await client.cmd(
"create",
{
"entity_id": TEST_ENTITY,
"event": {
"summary": "Bastille Day Party",
"dtstart": "1997-07-15T04:00:00+00:00",
"dtend": "1997-07-14T17:00:00+00:00",
},
},
)
assert not result.get("success")
assert "error" in result
assert "code" in result.get("error")
assert result["error"]["code"] == "invalid_format"
async def test_invalid_recurrence_rule(
ws_client: ClientFixture,
setup_integration: None,
):
"""Test an event with a recurrence rule."""
client = await ws_client()
result = await client.cmd(
"create",
{
"entity_id": TEST_ENTITY,
"event": {
"summary": "Monday meeting",
"dtstart": "2022-08-29T09:00:00",
"dtend": "2022-08-29T10:00:00",
"rrule": "FREQ=invalid;'",
},
},
)
assert not result.get("success")
assert "error" in result
assert "code" in result.get("error")
assert result["error"]["code"] == "invalid_format"
async def test_invalid_date_formats(
ws_client: ClientFixture, setup_integration: None, get_events: GetEventsFn
):
"""Exercises a validation error within rfc5545 parsing in ical."""
client = await ws_client()
result = await client.cmd(
"create",
{
"entity_id": TEST_ENTITY,
"event": {
"summary": "Bastille Day Party",
# Can't mix offset aware and floating dates
"dtstart": "1997-07-15T04:00:00+08:00",
"dtend": "1997-07-14T17:00:00",
},
},
)
assert not result.get("success")
assert "error" in result
assert "code" in result.get("error")
assert result["error"]["code"] == "invalid_format"