mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 07:07:28 +00:00
Allow complex schemas for validating WS commands (#91655)
This commit is contained in:
parent
90e92aa9d8
commit
4e0b8a7363
@ -33,10 +33,12 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
@callback
|
@callback
|
||||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
"""Register the websocket API."""
|
"""Register the websocket API."""
|
||||||
websocket_api.async_register_command(
|
websocket_api.async_register_command(hass, websocket_run)
|
||||||
hass,
|
websocket_api.async_register_command(hass, websocket_list_runs)
|
||||||
"assist_pipeline/run",
|
websocket_api.async_register_command(hass, websocket_get_run)
|
||||||
websocket_run,
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
vol.All(
|
vol.All(
|
||||||
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
|
||||||
{
|
{
|
||||||
@ -69,11 +71,7 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
websocket_api.async_register_command(hass, websocket_list_runs)
|
|
||||||
websocket_api.async_register_command(hass, websocket_get_run)
|
|
||||||
|
|
||||||
|
|
||||||
@websocket_api.async_response
|
@websocket_api.async_response
|
||||||
async def websocket_run(
|
async def websocket_run(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@ -128,15 +128,31 @@ def ws_require_user(
|
|||||||
|
|
||||||
|
|
||||||
def websocket_command(
|
def websocket_command(
|
||||||
schema: dict[vol.Marker, Any],
|
schema: dict[vol.Marker, Any] | vol.All,
|
||||||
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]:
|
||||||
"""Tag a function as a websocket command."""
|
"""Tag a function as a websocket command.
|
||||||
|
|
||||||
|
The schema must be either a dictionary where the keys are voluptuous markers, or
|
||||||
|
a voluptuous.All schema where the first item is a voluptuous Mapping schema.
|
||||||
|
"""
|
||||||
|
if isinstance(schema, dict):
|
||||||
command = schema["type"]
|
command = schema["type"]
|
||||||
|
else:
|
||||||
|
command = schema.validators[0].schema["type"]
|
||||||
|
|
||||||
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler:
|
||||||
"""Decorate ws command function."""
|
"""Decorate ws command function."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
if isinstance(schema, dict):
|
||||||
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined]
|
||||||
|
else:
|
||||||
|
extended_schema = vol.All(
|
||||||
|
schema.validators[0].extend(
|
||||||
|
messages.BASE_COMMAND_MESSAGE_SCHEMA.schema
|
||||||
|
),
|
||||||
|
*schema.validators[1:],
|
||||||
|
)
|
||||||
|
func._ws_schema = extended_schema # type: ignore[attr-defined]
|
||||||
func._ws_command = command # type: ignore[attr-defined]
|
func._ws_command = command # type: ignore[attr-defined]
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user