From 4ddcb140532f5771a2420a3ef52d066daede2fe5 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Tue, 14 Mar 2023 17:27:38 -0700 Subject: [PATCH] Add additional CalendarEvent validation (#89533) Add additional event validation --- homeassistant/components/caldav/calendar.py | 6 + homeassistant/components/calendar/__init__.py | 194 +++++++++++------- tests/components/calendar/test_init.py | 13 +- tests/components/google/test_init.py | 2 +- 4 files changed, 133 insertions(+), 82 deletions(-) diff --git a/homeassistant/components/caldav/calendar.py b/homeassistant/components/caldav/calendar.py index ab3c47b9690..9a01cd2186f 100644 --- a/homeassistant/components/caldav/calendar.py +++ b/homeassistant/components/caldav/calendar.py @@ -356,4 +356,10 @@ class WebDavCalendarData: else: enddate = obj.dtstart.value + timedelta(days=1) + # End date for an all day event is exclusive. This fixes the case where + # an all day event has a start and end values are the same, or the event + # has a zero duration. + if not isinstance(enddate, datetime) and obj.dtstart.value == enddate: + enddate += timedelta(days=1) + return enddate diff --git a/homeassistant/components/calendar/__init__.py b/homeassistant/components/calendar/__init__.py index c77d6c9c67a..d09a389ce82 100644 --- a/homeassistant/components/calendar/__init__.py +++ b/homeassistant/components/calendar/__init__.py @@ -67,6 +67,23 @@ SCAN_INTERVAL = datetime.timedelta(seconds=60) VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"} +def _has_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Assert that all datetime values have a timezone.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Validate that all datetime values have a timezone.""" + for k in keys: + if ( + (value := obj.get(k)) + and isinstance(value, datetime.datetime) + and value.tzinfo is None + ): + raise vol.Invalid("Expected all values to have a timezone") + return obj + + return validate + + def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: """Verify that all datetime values have a consistent timezone.""" @@ -89,7 +106,7 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]] """Convert all datetime values to the local timezone.""" def validate(obj: dict[str, Any]) -> dict[str, Any]: - """Test that all keys that are datetime values have the same timezone.""" + """Convert all keys that are datetime values to local timezone.""" for k in keys: if (value := obj.get(k)) and isinstance(value, datetime.datetime): obj[k] = dt.as_local(value) @@ -98,23 +115,59 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]] return validate -def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: - """Verify that the specified values are sequential.""" +def _has_duration( + start_key: str, end_key: str +) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that the time span between start and end is positive.""" def validate(obj: dict[str, Any]) -> dict[str, Any]: """Test that all keys in the dict are in order.""" - values = [] - for k in keys: - if not (value := obj.get(k)): - return obj - values.append(value) - if all(values) and values != sorted(values): - raise vol.Invalid(f"Values were not in order: {values}") + if (start := obj.get(start_key)) and (end := obj.get(end_key)): + duration = end - start + if duration.total_seconds() <= 0: + raise vol.Invalid(f"Expected positive event duration ({start}, {end})") return obj return validate +def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that all values are of the same type.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + """Test that all keys in the dict have values of the same type.""" + uniq_values = groupby(type(obj[k]) for k in keys) + if len(list(uniq_values)) > 1: + raise vol.Invalid(f"Expected all values to be the same type: {keys}") + return obj + + return validate + + +def _validate_rrule(value: Any) -> str: + """Validate a recurrence rule string.""" + if value is None: + raise vol.Invalid("rrule value is None") + + if not isinstance(value, str): + raise vol.Invalid("rrule value expected a string") + + try: + rrulestr(value) + except ValueError as err: + raise vol.Invalid(f"Invalid rrule: {str(err)}") from err + + # Example format: FREQ=DAILY;UNTIL=... + rule_parts = dict(s.split("=", 1) for s in value.split(";")) + if not (freq := rule_parts.get("FREQ")): + raise vol.Invalid("rrule did not contain FREQ") + + if freq not in VALID_FREQS: + raise vol.Invalid(f"Invalid frequency for rule: {value}") + + return str(value) + + CREATE_EVENT_SERVICE = "create_event" CREATE_EVENT_SCHEMA = vol.All( cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN), @@ -149,8 +202,42 @@ CREATE_EVENT_SCHEMA = vol.All( ), _has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME), _as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME), - _is_sorted(EVENT_START_DATE, EVENT_END_DATE), - _is_sorted(EVENT_START_DATETIME, EVENT_END_DATETIME), + _has_duration(EVENT_START_DATE, EVENT_END_DATE), + _has_duration(EVENT_START_DATETIME, EVENT_END_DATETIME), +) + +WEBSOCKET_EVENT_SCHEMA = vol.Schema( + vol.All( + { + vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime), + vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime), + vol.Required(EVENT_SUMMARY): cv.string, + vol.Optional(EVENT_DESCRIPTION): cv.string, + vol.Optional(EVENT_RRULE): _validate_rrule, + }, + _has_same_type(EVENT_START, EVENT_END), + _has_consistent_timezone(EVENT_START, EVENT_END), + _as_local_timezone(EVENT_START, EVENT_END), + _has_duration(EVENT_START, EVENT_END), + ) +) + +# Validation for the CalendarEvent dataclass +CALENDAR_EVENT_SCHEMA = vol.Schema( + vol.All( + { + vol.Required("start"): vol.Any(cv.date, cv.datetime), + vol.Required("end"): vol.Any(cv.date, cv.datetime), + vol.Required(EVENT_SUMMARY): cv.string, + vol.Optional(EVENT_RRULE): _validate_rrule, + }, + _has_same_type("start", "end"), + _has_timezone("start", "end"), + _has_consistent_timezone("start", "end"), + _as_local_timezone("start", "end"), + _has_duration("start", "end"), + ), + extra=vol.ALLOW_EXTRA, ) @@ -243,6 +330,19 @@ class CalendarEvent: "all_day": self.all_day, } + def __post_init__(self) -> None: + """Perform validation on the CalendarEvent.""" + + def skip_none(obj: Iterable[tuple[str, Any]]) -> dict[str, str]: + return {k: v for k, v in obj if v is not None} + + try: + CALENDAR_EVENT_SCHEMA(dataclasses.asdict(self, dict_factory=skip_none)) + except vol.Invalid as err: + raise HomeAssistantError( + f"Failed to validate CalendarEvent: {err}" + ) from err + def _event_dict_factory(obj: Iterable[tuple[str, Any]]) -> dict[str, str]: """Convert CalendarEvent dataclass items to dictionary of attributes.""" @@ -316,30 +416,6 @@ def is_offset_reached( return start + offset_time <= dt.now(start.tzinfo) -def _validate_rrule(value: Any) -> str: - """Validate a recurrence rule string.""" - if value is None: - raise vol.Invalid("rrule value is None") - - if not isinstance(value, str): - raise vol.Invalid("rrule value expected a string") - - try: - rrulestr(value) - except ValueError as err: - raise vol.Invalid(f"Invalid rrule: {str(err)}") from err - - # Example format: FREQ=DAILY;UNTIL=... - rule_parts = dict(s.split("=", 1) for s in value.split(";")) - if not (freq := rule_parts.get("FREQ")): - raise vol.Invalid("rrule did not contain FREQ") - - if freq not in VALID_FREQS: - raise vol.Invalid(f"Invalid frequency for rule: {value}") - - return str(value) - - class CalendarEntity(Entity): """Base class for calendar event entities.""" @@ -447,6 +523,7 @@ class CalendarEventView(http.HomeAssistantView): request.app["hass"], start_date, end_date ) except HomeAssistantError as err: + _LOGGER.debug("Error reading events: %s", err) return self.json_message( f"Error reading events: {err}", HTTPStatus.INTERNAL_SERVER_ERROR ) @@ -481,38 +558,11 @@ class CalendarListView(http.HomeAssistantView): return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"]))) -def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: - """Verify that all values are of the same type.""" - - def validate(obj: dict[str, Any]) -> dict[str, Any]: - """Test that all keys in the dict have values of the same type.""" - uniq_values = groupby(type(obj[k]) for k in keys) - if len(list(uniq_values)) > 1: - raise vol.Invalid(f"Expected all values to be the same type: {keys}") - return obj - - return validate - - @websocket_api.websocket_command( { vol.Required("type"): "calendar/event/create", vol.Required("entity_id"): cv.entity_id, - CONF_EVENT: vol.Schema( - vol.All( - { - vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime), - vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime), - vol.Required(EVENT_SUMMARY): cv.string, - vol.Optional(EVENT_DESCRIPTION): cv.string, - vol.Optional(EVENT_RRULE): _validate_rrule, - }, - _has_same_type(EVENT_START, EVENT_END), - _has_consistent_timezone(EVENT_START, EVENT_END), - _as_local_timezone(EVENT_START, EVENT_END), - _is_sorted(EVENT_START, EVENT_END), - ) - ), + CONF_EVENT: WEBSOCKET_EVENT_SCHEMA, } ) @websocket_api.async_response @@ -595,21 +645,7 @@ async def handle_calendar_event_delete( vol.Required(EVENT_UID): cv.string, vol.Optional(EVENT_RECURRENCE_ID): cv.string, vol.Optional(EVENT_RECURRENCE_RANGE): cv.string, - vol.Required(CONF_EVENT): vol.Schema( - vol.All( - { - vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime), - vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime), - vol.Required(EVENT_SUMMARY): cv.string, - vol.Optional(EVENT_DESCRIPTION): cv.string, - vol.Optional(EVENT_RRULE): _validate_rrule, - }, - _has_same_type(EVENT_START, EVENT_END), - _has_consistent_timezone(EVENT_START, EVENT_END), - _as_local_timezone(EVENT_START, EVENT_END), - _is_sorted(EVENT_START, EVENT_END), - ) - ), + vol.Required(CONF_EVENT): WEBSOCKET_EVENT_SCHEMA, } ) @websocket_api.async_response diff --git a/tests/components/calendar/test_init.py b/tests/components/calendar/test_init.py index 5c90a1cfc2c..875d5bf8c13 100644 --- a/tests/components/calendar/test_init.py +++ b/tests/components/calendar/test_init.py @@ -324,7 +324,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None: "end_date_time": "2022-04-01T06:00:00", }, vol.error.MultipleInvalid, - "Values were not in order", + "Expected positive event duration", ), ( { @@ -332,7 +332,15 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None: "end_date": "2022-04-01", }, vol.error.MultipleInvalid, - "Values were not in order", + "Expected positive event duration", + ), + ( + { + "start_date": "2022-04-01", + "end_date": "2022-04-01", + }, + vol.error.MultipleInvalid, + "Expected positive event duration", ), ], ids=[ @@ -351,6 +359,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None: "inconsistent_timezone", "incorrect_date_order", "incorrect_datetime_order", + "dates_not_exclusive", ], ) async def test_create_event_service_invalid_params( diff --git a/tests/components/google/test_init.py b/tests/components/google/test_init.py index 28525acd468..eac3bff5854 100644 --- a/tests/components/google/test_init.py +++ b/tests/components/google/test_init.py @@ -597,7 +597,7 @@ async def test_add_event_failure( with pytest.raises(HomeAssistantError): await add_event_call_service( - {"start_date": "2022-05-01", "end_date": "2022-05-01"} + {"start_date": "2022-05-01", "end_date": "2022-05-02"} )