diff --git a/homeassistant/components/calendar/__init__.py b/homeassistant/components/calendar/__init__.py index b5d605b9f6f..c89b36ce636 100644 --- a/homeassistant/components/calendar/__init__.py +++ b/homeassistant/components/calendar/__init__.py @@ -1,10 +1,11 @@ """Support for Google Calendar event device sensors.""" from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Callable, Iterable import dataclasses import datetime from http import HTTPStatus +from itertools import groupby import logging import re from typing import Any, cast, final @@ -365,17 +366,67 @@ class CalendarListView(http.HomeAssistantView): return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"]))) +def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that all values are of the same type.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys in the dict have values of the same type.""" + uniq_values = groupby(type(obj[k]) for k in keys) + if len(list(uniq_values)) > 1: + raise vol.Invalid(f"Expected all values to be the same type: {keys}") + return obj + + return validate + + +def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that all datetime values have a consistent timezone.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys that are datetime values have the same timezone.""" + values = [obj[k] for k in keys] + if all(isinstance(value, datetime.datetime) for value in values): + uniq_values = groupby(value.tzinfo for value in values) + if len(list(uniq_values)) > 1: + raise vol.Invalid( + f"Expected all values to have the same timezone: {values}" + ) + return obj + + return validate + + +def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that the specified values are sequential.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys in the dict are in order.""" + values = [obj[k] for k in keys] + if values != sorted(values): + raise vol.Invalid(f"Values were not in order: {values}") + return obj + + return validate + + @websocket_api.websocket_command( { vol.Required("type"): "calendar/event/create", vol.Required("entity_id"): cv.entity_id, - vol.Required(CONF_EVENT): { - vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime), - vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime), - vol.Required(EVENT_SUMMARY): cv.string, - vol.Optional(EVENT_DESCRIPTION): cv.string, - vol.Optional(EVENT_RRULE): _validate_rrule, - }, + CONF_EVENT: vol.Schema( + vol.All( + { + vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime), + vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime), + vol.Required(EVENT_SUMMARY): cv.string, + vol.Optional(EVENT_DESCRIPTION): cv.string, + vol.Optional(EVENT_RRULE): _validate_rrule, + }, + _has_same_type(EVENT_START, EVENT_END), + _has_consistent_timezone(EVENT_START, EVENT_END), + _is_sorted(EVENT_START, EVENT_END), + ) + ), } ) @websocket_api.async_response diff --git a/tests/components/local_calendar/test_calendar.py b/tests/components/local_calendar/test_calendar.py index 77632c8bfe1..092fcb1c1fb 100644 --- a/tests/components/local_calendar/test_calendar.py +++ b/tests/components/local_calendar/test_calendar.py @@ -612,3 +612,96 @@ async def test_all_day_iter_order( events = await get_events("2022-10-06T00:00:00Z", "2022-10-09T00:00:00Z") assert [event["summary"] for event in events] == event_order + + +async def test_start_end_types( + ws_client: ClientFixture, + setup_integration: None, +): + """Test a start and end with different date and date time types.""" + client = await ws_client() + result = await client.cmd( + "create", + { + "entity_id": TEST_ENTITY, + "event": { + "summary": "Bastille Day Party", + "dtstart": "1997-07-15", + "dtend": "1997-07-14T17:00:00+00:00", + }, + }, + ) + assert not result.get("success") + assert "error" in result + assert "code" in result.get("error") + assert result["error"]["code"] == "invalid_format" + + +async def test_end_before_start( + ws_client: ClientFixture, + setup_integration: None, +): + """Test an event with a start/end date time.""" + client = await ws_client() + result = await client.cmd( + "create", + { + "entity_id": TEST_ENTITY, + "event": { + "summary": "Bastille Day Party", + "dtstart": "1997-07-15T04:00:00+00:00", + "dtend": "1997-07-14T17:00:00+00:00", + }, + }, + ) + assert not result.get("success") + assert "error" in result + assert "code" in result.get("error") + assert result["error"]["code"] == "invalid_format" + + +async def test_invalid_recurrence_rule( + ws_client: ClientFixture, + setup_integration: None, +): + """Test an event with a recurrence rule.""" + client = await ws_client() + result = await client.cmd( + "create", + { + "entity_id": TEST_ENTITY, + "event": { + "summary": "Monday meeting", + "dtstart": "2022-08-29T09:00:00", + "dtend": "2022-08-29T10:00:00", + "rrule": "FREQ=invalid;'", + }, + }, + ) + assert not result.get("success") + assert "error" in result + assert "code" in result.get("error") + assert result["error"]["code"] == "invalid_format" + + +async def test_invalid_date_formats( + ws_client: ClientFixture, setup_integration: None, get_events: GetEventsFn +): + """Exercises a validation error within rfc5545 parsing in ical.""" + client = await ws_client() + result = await client.cmd( + "create", + { + "entity_id": TEST_ENTITY, + "event": { + "summary": "Bastille Day Party", + # Can't mix offset aware and floating dates + "dtstart": "1997-07-15T04:00:00+08:00", + "dtend": "1997-07-14T17:00:00", + }, + }, + ) + assert not result.get("success") + assert "error" in result + assert "code" in result.get("error") + assert result["error"]["code"] == "invalid_format"