Add additional CalendarEvent validation (#89533)

Add additional event validation
This commit is contained in:
Allen Porter 2023-03-14 17:27:38 -07:00 committed by GitHub
parent c33ca4f664
commit 4ddcb14053
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 82 deletions

View File

@ -356,4 +356,10 @@ class WebDavCalendarData:
else: else:
enddate = obj.dtstart.value + timedelta(days=1) enddate = obj.dtstart.value + timedelta(days=1)
# End date for an all day event is exclusive. This fixes the case where
# an all day event has a start and end values are the same, or the event
# has a zero duration.
if not isinstance(enddate, datetime) and obj.dtstart.value == enddate:
enddate += timedelta(days=1)
return enddate return enddate

View File

@ -67,6 +67,23 @@ SCAN_INTERVAL = datetime.timedelta(seconds=60)
VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"} VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"}
def _has_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Assert that all datetime values have a timezone."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Validate that all datetime values have a timezone."""
for k in keys:
if (
(value := obj.get(k))
and isinstance(value, datetime.datetime)
and value.tzinfo is None
):
raise vol.Invalid("Expected all values to have a timezone")
return obj
return validate
def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all datetime values have a consistent timezone.""" """Verify that all datetime values have a consistent timezone."""
@ -89,7 +106,7 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]
"""Convert all datetime values to the local timezone.""" """Convert all datetime values to the local timezone."""
def validate(obj: dict[str, Any]) -> dict[str, Any]: def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys that are datetime values have the same timezone.""" """Convert all keys that are datetime values to local timezone."""
for k in keys: for k in keys:
if (value := obj.get(k)) and isinstance(value, datetime.datetime): if (value := obj.get(k)) and isinstance(value, datetime.datetime):
obj[k] = dt.as_local(value) obj[k] = dt.as_local(value)
@ -98,23 +115,59 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]
return validate return validate
def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: def _has_duration(
"""Verify that the specified values are sequential.""" start_key: str, end_key: str
) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the time span between start and end is positive."""
def validate(obj: dict[str, Any]) -> dict[str, Any]: def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict are in order.""" """Test that all keys in the dict are in order."""
values = [] if (start := obj.get(start_key)) and (end := obj.get(end_key)):
for k in keys: duration = end - start
if not (value := obj.get(k)): if duration.total_seconds() <= 0:
return obj raise vol.Invalid(f"Expected positive event duration ({start}, {end})")
values.append(value)
if all(values) and values != sorted(values):
raise vol.Invalid(f"Values were not in order: {values}")
return obj return obj
return validate return validate
def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all values are of the same type."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict have values of the same type."""
uniq_values = groupby(type(obj[k]) for k in keys)
if len(list(uniq_values)) > 1:
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
return obj
return validate
def _validate_rrule(value: Any) -> str:
"""Validate a recurrence rule string."""
if value is None:
raise vol.Invalid("rrule value is None")
if not isinstance(value, str):
raise vol.Invalid("rrule value expected a string")
try:
rrulestr(value)
except ValueError as err:
raise vol.Invalid(f"Invalid rrule: {str(err)}") from err
# Example format: FREQ=DAILY;UNTIL=...
rule_parts = dict(s.split("=", 1) for s in value.split(";"))
if not (freq := rule_parts.get("FREQ")):
raise vol.Invalid("rrule did not contain FREQ")
if freq not in VALID_FREQS:
raise vol.Invalid(f"Invalid frequency for rule: {value}")
return str(value)
CREATE_EVENT_SERVICE = "create_event" CREATE_EVENT_SERVICE = "create_event"
CREATE_EVENT_SCHEMA = vol.All( CREATE_EVENT_SCHEMA = vol.All(
cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN), cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN),
@ -149,8 +202,42 @@ CREATE_EVENT_SCHEMA = vol.All(
), ),
_has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME), _has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
_as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME), _as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
_is_sorted(EVENT_START_DATE, EVENT_END_DATE), _has_duration(EVENT_START_DATE, EVENT_END_DATE),
_is_sorted(EVENT_START_DATETIME, EVENT_END_DATETIME), _has_duration(EVENT_START_DATETIME, EVENT_END_DATETIME),
)
WEBSOCKET_EVENT_SCHEMA = vol.Schema(
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_has_duration(EVENT_START, EVENT_END),
)
)
# Validation for the CalendarEvent dataclass
CALENDAR_EVENT_SCHEMA = vol.Schema(
vol.All(
{
vol.Required("start"): vol.Any(cv.date, cv.datetime),
vol.Required("end"): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type("start", "end"),
_has_timezone("start", "end"),
_has_consistent_timezone("start", "end"),
_as_local_timezone("start", "end"),
_has_duration("start", "end"),
),
extra=vol.ALLOW_EXTRA,
) )
@ -243,6 +330,19 @@ class CalendarEvent:
"all_day": self.all_day, "all_day": self.all_day,
} }
def __post_init__(self) -> None:
"""Perform validation on the CalendarEvent."""
def skip_none(obj: Iterable[tuple[str, Any]]) -> dict[str, str]:
return {k: v for k, v in obj if v is not None}
try:
CALENDAR_EVENT_SCHEMA(dataclasses.asdict(self, dict_factory=skip_none))
except vol.Invalid as err:
raise HomeAssistantError(
f"Failed to validate CalendarEvent: {err}"
) from err
def _event_dict_factory(obj: Iterable[tuple[str, Any]]) -> dict[str, str]: def _event_dict_factory(obj: Iterable[tuple[str, Any]]) -> dict[str, str]:
"""Convert CalendarEvent dataclass items to dictionary of attributes.""" """Convert CalendarEvent dataclass items to dictionary of attributes."""
@ -316,30 +416,6 @@ def is_offset_reached(
return start + offset_time <= dt.now(start.tzinfo) return start + offset_time <= dt.now(start.tzinfo)
def _validate_rrule(value: Any) -> str:
"""Validate a recurrence rule string."""
if value is None:
raise vol.Invalid("rrule value is None")
if not isinstance(value, str):
raise vol.Invalid("rrule value expected a string")
try:
rrulestr(value)
except ValueError as err:
raise vol.Invalid(f"Invalid rrule: {str(err)}") from err
# Example format: FREQ=DAILY;UNTIL=...
rule_parts = dict(s.split("=", 1) for s in value.split(";"))
if not (freq := rule_parts.get("FREQ")):
raise vol.Invalid("rrule did not contain FREQ")
if freq not in VALID_FREQS:
raise vol.Invalid(f"Invalid frequency for rule: {value}")
return str(value)
class CalendarEntity(Entity): class CalendarEntity(Entity):
"""Base class for calendar event entities.""" """Base class for calendar event entities."""
@ -447,6 +523,7 @@ class CalendarEventView(http.HomeAssistantView):
request.app["hass"], start_date, end_date request.app["hass"], start_date, end_date
) )
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.debug("Error reading events: %s", err)
return self.json_message( return self.json_message(
f"Error reading events: {err}", HTTPStatus.INTERNAL_SERVER_ERROR f"Error reading events: {err}", HTTPStatus.INTERNAL_SERVER_ERROR
) )
@ -481,38 +558,11 @@ class CalendarListView(http.HomeAssistantView):
return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"]))) return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"])))
def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all values are of the same type."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
"""Test that all keys in the dict have values of the same type."""
uniq_values = groupby(type(obj[k]) for k in keys)
if len(list(uniq_values)) > 1:
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
return obj
return validate
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "calendar/event/create", vol.Required("type"): "calendar/event/create",
vol.Required("entity_id"): cv.entity_id, vol.Required("entity_id"): cv.entity_id,
CONF_EVENT: vol.Schema( CONF_EVENT: WEBSOCKET_EVENT_SCHEMA,
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),
} }
) )
@websocket_api.async_response @websocket_api.async_response
@ -595,21 +645,7 @@ async def handle_calendar_event_delete(
vol.Required(EVENT_UID): cv.string, vol.Required(EVENT_UID): cv.string,
vol.Optional(EVENT_RECURRENCE_ID): cv.string, vol.Optional(EVENT_RECURRENCE_ID): cv.string,
vol.Optional(EVENT_RECURRENCE_RANGE): cv.string, vol.Optional(EVENT_RECURRENCE_RANGE): cv.string,
vol.Required(CONF_EVENT): vol.Schema( vol.Required(CONF_EVENT): WEBSOCKET_EVENT_SCHEMA,
vol.All(
{
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
vol.Required(EVENT_SUMMARY): cv.string,
vol.Optional(EVENT_DESCRIPTION): cv.string,
vol.Optional(EVENT_RRULE): _validate_rrule,
},
_has_same_type(EVENT_START, EVENT_END),
_has_consistent_timezone(EVENT_START, EVENT_END),
_as_local_timezone(EVENT_START, EVENT_END),
_is_sorted(EVENT_START, EVENT_END),
)
),
} }
) )
@websocket_api.async_response @websocket_api.async_response

View File

@ -324,7 +324,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
"end_date_time": "2022-04-01T06:00:00", "end_date_time": "2022-04-01T06:00:00",
}, },
vol.error.MultipleInvalid, vol.error.MultipleInvalid,
"Values were not in order", "Expected positive event duration",
), ),
( (
{ {
@ -332,7 +332,15 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
"end_date": "2022-04-01", "end_date": "2022-04-01",
}, },
vol.error.MultipleInvalid, vol.error.MultipleInvalid,
"Values were not in order", "Expected positive event duration",
),
(
{
"start_date": "2022-04-01",
"end_date": "2022-04-01",
},
vol.error.MultipleInvalid,
"Expected positive event duration",
), ),
], ],
ids=[ ids=[
@ -351,6 +359,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
"inconsistent_timezone", "inconsistent_timezone",
"incorrect_date_order", "incorrect_date_order",
"incorrect_datetime_order", "incorrect_datetime_order",
"dates_not_exclusive",
], ],
) )
async def test_create_event_service_invalid_params( async def test_create_event_service_invalid_params(

View File

@ -597,7 +597,7 @@ async def test_add_event_failure(
with pytest.raises(HomeAssistantError): with pytest.raises(HomeAssistantError):
await add_event_call_service( await add_event_call_service(
{"start_date": "2022-05-01", "end_date": "2022-05-01"} {"start_date": "2022-05-01", "end_date": "2022-05-02"}
) )