mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Add additional CalendarEvent validation (#89533)
Add additional event validation
This commit is contained in:
parent
c33ca4f664
commit
4ddcb14053
@ -356,4 +356,10 @@ class WebDavCalendarData:
|
||||
else:
|
||||
enddate = obj.dtstart.value + timedelta(days=1)
|
||||
|
||||
# End date for an all day event is exclusive. This fixes the case where
|
||||
# an all day event has a start and end values are the same, or the event
|
||||
# has a zero duration.
|
||||
if not isinstance(enddate, datetime) and obj.dtstart.value == enddate:
|
||||
enddate += timedelta(days=1)
|
||||
|
||||
return enddate
|
||||
|
@ -67,6 +67,23 @@ SCAN_INTERVAL = datetime.timedelta(seconds=60)
|
||||
VALID_FREQS = {"DAILY", "WEEKLY", "MONTHLY", "YEARLY"}
|
||||
|
||||
|
||||
def _has_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""Assert that all datetime values have a timezone."""
|
||||
|
||||
def validate(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate that all datetime values have a timezone."""
|
||||
for k in keys:
|
||||
if (
|
||||
(value := obj.get(k))
|
||||
and isinstance(value, datetime.datetime)
|
||||
and value.tzinfo is None
|
||||
):
|
||||
raise vol.Invalid("Expected all values to have a timezone")
|
||||
return obj
|
||||
|
||||
return validate
|
||||
|
||||
|
||||
def _has_consistent_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""Verify that all datetime values have a consistent timezone."""
|
||||
|
||||
@ -89,7 +106,7 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]
|
||||
"""Convert all datetime values to the local timezone."""
|
||||
|
||||
def validate(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Test that all keys that are datetime values have the same timezone."""
|
||||
"""Convert all keys that are datetime values to local timezone."""
|
||||
for k in keys:
|
||||
if (value := obj.get(k)) and isinstance(value, datetime.datetime):
|
||||
obj[k] = dt.as_local(value)
|
||||
@ -98,23 +115,59 @@ def _as_local_timezone(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]
|
||||
return validate
|
||||
|
||||
|
||||
def _is_sorted(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""Verify that the specified values are sequential."""
|
||||
def _has_duration(
|
||||
start_key: str, end_key: str
|
||||
) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""Verify that the time span between start and end is positive."""
|
||||
|
||||
def validate(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Test that all keys in the dict are in order."""
|
||||
values = []
|
||||
for k in keys:
|
||||
if not (value := obj.get(k)):
|
||||
return obj
|
||||
values.append(value)
|
||||
if all(values) and values != sorted(values):
|
||||
raise vol.Invalid(f"Values were not in order: {values}")
|
||||
if (start := obj.get(start_key)) and (end := obj.get(end_key)):
|
||||
duration = end - start
|
||||
if duration.total_seconds() <= 0:
|
||||
raise vol.Invalid(f"Expected positive event duration ({start}, {end})")
|
||||
return obj
|
||||
|
||||
return validate
|
||||
|
||||
|
||||
def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""Verify that all values are of the same type."""
|
||||
|
||||
def validate(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Test that all keys in the dict have values of the same type."""
|
||||
uniq_values = groupby(type(obj[k]) for k in keys)
|
||||
if len(list(uniq_values)) > 1:
|
||||
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
|
||||
return obj
|
||||
|
||||
return validate
|
||||
|
||||
|
||||
def _validate_rrule(value: Any) -> str:
|
||||
"""Validate a recurrence rule string."""
|
||||
if value is None:
|
||||
raise vol.Invalid("rrule value is None")
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise vol.Invalid("rrule value expected a string")
|
||||
|
||||
try:
|
||||
rrulestr(value)
|
||||
except ValueError as err:
|
||||
raise vol.Invalid(f"Invalid rrule: {str(err)}") from err
|
||||
|
||||
# Example format: FREQ=DAILY;UNTIL=...
|
||||
rule_parts = dict(s.split("=", 1) for s in value.split(";"))
|
||||
if not (freq := rule_parts.get("FREQ")):
|
||||
raise vol.Invalid("rrule did not contain FREQ")
|
||||
|
||||
if freq not in VALID_FREQS:
|
||||
raise vol.Invalid(f"Invalid frequency for rule: {value}")
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
CREATE_EVENT_SERVICE = "create_event"
|
||||
CREATE_EVENT_SCHEMA = vol.All(
|
||||
cv.has_at_least_one_key(EVENT_START_DATE, EVENT_START_DATETIME, EVENT_IN),
|
||||
@ -149,8 +202,42 @@ CREATE_EVENT_SCHEMA = vol.All(
|
||||
),
|
||||
_has_consistent_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
|
||||
_as_local_timezone(EVENT_START_DATETIME, EVENT_END_DATETIME),
|
||||
_is_sorted(EVENT_START_DATE, EVENT_END_DATE),
|
||||
_is_sorted(EVENT_START_DATETIME, EVENT_END_DATETIME),
|
||||
_has_duration(EVENT_START_DATE, EVENT_END_DATE),
|
||||
_has_duration(EVENT_START_DATETIME, EVENT_END_DATETIME),
|
||||
)
|
||||
|
||||
WEBSOCKET_EVENT_SCHEMA = vol.Schema(
|
||||
vol.All(
|
||||
{
|
||||
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_SUMMARY): cv.string,
|
||||
vol.Optional(EVENT_DESCRIPTION): cv.string,
|
||||
vol.Optional(EVENT_RRULE): _validate_rrule,
|
||||
},
|
||||
_has_same_type(EVENT_START, EVENT_END),
|
||||
_has_consistent_timezone(EVENT_START, EVENT_END),
|
||||
_as_local_timezone(EVENT_START, EVENT_END),
|
||||
_has_duration(EVENT_START, EVENT_END),
|
||||
)
|
||||
)
|
||||
|
||||
# Validation for the CalendarEvent dataclass
|
||||
CALENDAR_EVENT_SCHEMA = vol.Schema(
|
||||
vol.All(
|
||||
{
|
||||
vol.Required("start"): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required("end"): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_SUMMARY): cv.string,
|
||||
vol.Optional(EVENT_RRULE): _validate_rrule,
|
||||
},
|
||||
_has_same_type("start", "end"),
|
||||
_has_timezone("start", "end"),
|
||||
_has_consistent_timezone("start", "end"),
|
||||
_as_local_timezone("start", "end"),
|
||||
_has_duration("start", "end"),
|
||||
),
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
@ -243,6 +330,19 @@ class CalendarEvent:
|
||||
"all_day": self.all_day,
|
||||
}
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Perform validation on the CalendarEvent."""
|
||||
|
||||
def skip_none(obj: Iterable[tuple[str, Any]]) -> dict[str, str]:
|
||||
return {k: v for k, v in obj if v is not None}
|
||||
|
||||
try:
|
||||
CALENDAR_EVENT_SCHEMA(dataclasses.asdict(self, dict_factory=skip_none))
|
||||
except vol.Invalid as err:
|
||||
raise HomeAssistantError(
|
||||
f"Failed to validate CalendarEvent: {err}"
|
||||
) from err
|
||||
|
||||
|
||||
def _event_dict_factory(obj: Iterable[tuple[str, Any]]) -> dict[str, str]:
|
||||
"""Convert CalendarEvent dataclass items to dictionary of attributes."""
|
||||
@ -316,30 +416,6 @@ def is_offset_reached(
|
||||
return start + offset_time <= dt.now(start.tzinfo)
|
||||
|
||||
|
||||
def _validate_rrule(value: Any) -> str:
|
||||
"""Validate a recurrence rule string."""
|
||||
if value is None:
|
||||
raise vol.Invalid("rrule value is None")
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise vol.Invalid("rrule value expected a string")
|
||||
|
||||
try:
|
||||
rrulestr(value)
|
||||
except ValueError as err:
|
||||
raise vol.Invalid(f"Invalid rrule: {str(err)}") from err
|
||||
|
||||
# Example format: FREQ=DAILY;UNTIL=...
|
||||
rule_parts = dict(s.split("=", 1) for s in value.split(";"))
|
||||
if not (freq := rule_parts.get("FREQ")):
|
||||
raise vol.Invalid("rrule did not contain FREQ")
|
||||
|
||||
if freq not in VALID_FREQS:
|
||||
raise vol.Invalid(f"Invalid frequency for rule: {value}")
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
class CalendarEntity(Entity):
|
||||
"""Base class for calendar event entities."""
|
||||
|
||||
@ -447,6 +523,7 @@ class CalendarEventView(http.HomeAssistantView):
|
||||
request.app["hass"], start_date, end_date
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.debug("Error reading events: %s", err)
|
||||
return self.json_message(
|
||||
f"Error reading events: {err}", HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
@ -481,38 +558,11 @@ class CalendarListView(http.HomeAssistantView):
|
||||
return self.json(sorted(calendar_list, key=lambda x: cast(str, x["name"])))
|
||||
|
||||
|
||||
def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""Verify that all values are of the same type."""
|
||||
|
||||
def validate(obj: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Test that all keys in the dict have values of the same type."""
|
||||
uniq_values = groupby(type(obj[k]) for k in keys)
|
||||
if len(list(uniq_values)) > 1:
|
||||
raise vol.Invalid(f"Expected all values to be the same type: {keys}")
|
||||
return obj
|
||||
|
||||
return validate
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "calendar/event/create",
|
||||
vol.Required("entity_id"): cv.entity_id,
|
||||
CONF_EVENT: vol.Schema(
|
||||
vol.All(
|
||||
{
|
||||
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_SUMMARY): cv.string,
|
||||
vol.Optional(EVENT_DESCRIPTION): cv.string,
|
||||
vol.Optional(EVENT_RRULE): _validate_rrule,
|
||||
},
|
||||
_has_same_type(EVENT_START, EVENT_END),
|
||||
_has_consistent_timezone(EVENT_START, EVENT_END),
|
||||
_as_local_timezone(EVENT_START, EVENT_END),
|
||||
_is_sorted(EVENT_START, EVENT_END),
|
||||
)
|
||||
),
|
||||
CONF_EVENT: WEBSOCKET_EVENT_SCHEMA,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
@ -595,21 +645,7 @@ async def handle_calendar_event_delete(
|
||||
vol.Required(EVENT_UID): cv.string,
|
||||
vol.Optional(EVENT_RECURRENCE_ID): cv.string,
|
||||
vol.Optional(EVENT_RECURRENCE_RANGE): cv.string,
|
||||
vol.Required(CONF_EVENT): vol.Schema(
|
||||
vol.All(
|
||||
{
|
||||
vol.Required(EVENT_START): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_END): vol.Any(cv.date, cv.datetime),
|
||||
vol.Required(EVENT_SUMMARY): cv.string,
|
||||
vol.Optional(EVENT_DESCRIPTION): cv.string,
|
||||
vol.Optional(EVENT_RRULE): _validate_rrule,
|
||||
},
|
||||
_has_same_type(EVENT_START, EVENT_END),
|
||||
_has_consistent_timezone(EVENT_START, EVENT_END),
|
||||
_as_local_timezone(EVENT_START, EVENT_END),
|
||||
_is_sorted(EVENT_START, EVENT_END),
|
||||
)
|
||||
),
|
||||
vol.Required(CONF_EVENT): WEBSOCKET_EVENT_SCHEMA,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
@ -324,7 +324,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
|
||||
"end_date_time": "2022-04-01T06:00:00",
|
||||
},
|
||||
vol.error.MultipleInvalid,
|
||||
"Values were not in order",
|
||||
"Expected positive event duration",
|
||||
),
|
||||
(
|
||||
{
|
||||
@ -332,7 +332,15 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
|
||||
"end_date": "2022-04-01",
|
||||
},
|
||||
vol.error.MultipleInvalid,
|
||||
"Values were not in order",
|
||||
"Expected positive event duration",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_date": "2022-04-01",
|
||||
"end_date": "2022-04-01",
|
||||
},
|
||||
vol.error.MultipleInvalid,
|
||||
"Expected positive event duration",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
@ -351,6 +359,7 @@ async def test_unsupported_create_event_service(hass: HomeAssistant) -> None:
|
||||
"inconsistent_timezone",
|
||||
"incorrect_date_order",
|
||||
"incorrect_datetime_order",
|
||||
"dates_not_exclusive",
|
||||
],
|
||||
)
|
||||
async def test_create_event_service_invalid_params(
|
||||
|
@ -597,7 +597,7 @@ async def test_add_event_failure(
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await add_event_call_service(
|
||||
{"start_date": "2022-05-01", "end_date": "2022-05-01"}
|
||||
{"start_date": "2022-05-01", "end_date": "2022-05-02"}
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user