diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index b672e0c6b25..adee30c2012 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -33,47 +33,45 @@ _LOGGER = logging.getLogger(__name__) @callback def async_register_websocket_api(hass: HomeAssistant) -> None: """Register the websocket API.""" - websocket_api.async_register_command( - hass, - "assist_pipeline/run", - websocket_run, - vol.All( - websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( - { - vol.Required("type"): "assist_pipeline/run", - # pylint: disable-next=unnecessary-lambda - vol.Required("start_stage"): lambda val: PipelineStage(val), - # pylint: disable-next=unnecessary-lambda - vol.Required("end_stage"): lambda val: PipelineStage(val), - vol.Optional("input"): dict, - vol.Optional("pipeline"): str, - vol.Optional("conversation_id"): vol.Any(str, None), - vol.Optional("timeout"): vol.Any(float, int), - }, - ), - cv.key_value_schemas( - "start_stage", - { - PipelineStage.STT: vol.Schema( - {vol.Required("input"): {vol.Required("sample_rate"): int}}, - extra=vol.ALLOW_EXTRA, - ), - PipelineStage.INTENT: vol.Schema( - {vol.Required("input"): {"text": str}}, - extra=vol.ALLOW_EXTRA, - ), - PipelineStage.TTS: vol.Schema( - {vol.Required("input"): {"text": str}}, - extra=vol.ALLOW_EXTRA, - ), - }, - ), - ), - ) + websocket_api.async_register_command(hass, websocket_run) websocket_api.async_register_command(hass, websocket_list_runs) websocket_api.async_register_command(hass, websocket_get_run) +@websocket_api.websocket_command( + vol.All( + websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( + { + vol.Required("type"): "assist_pipeline/run", + # pylint: disable-next=unnecessary-lambda + vol.Required("start_stage"): lambda val: PipelineStage(val), + # pylint: disable-next=unnecessary-lambda + vol.Required("end_stage"): lambda val: PipelineStage(val), + vol.Optional("input"): dict, + vol.Optional("pipeline"): str, + vol.Optional("conversation_id"): vol.Any(str, None), + vol.Optional("timeout"): vol.Any(float, int), + }, + ), + cv.key_value_schemas( + "start_stage", + { + PipelineStage.STT: vol.Schema( + {vol.Required("input"): {vol.Required("sample_rate"): int}}, + extra=vol.ALLOW_EXTRA, + ), + PipelineStage.INTENT: vol.Schema( + {vol.Required("input"): {"text": str}}, + extra=vol.ALLOW_EXTRA, + ), + PipelineStage.TTS: vol.Schema( + {vol.Required("input"): {"text": str}}, + extra=vol.ALLOW_EXTRA, + ), + }, + ), + ), +) @websocket_api.async_response async def websocket_run( hass: HomeAssistant, diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index 9afffd9fb28..a148ed2be8d 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -128,15 +128,31 @@ def ws_require_user( def websocket_command( - schema: dict[vol.Marker, Any], + schema: dict[vol.Marker, Any] | vol.All, ) -> Callable[[const.WebSocketCommandHandler], const.WebSocketCommandHandler]: - """Tag a function as a websocket command.""" - command = schema["type"] + """Tag a function as a websocket command. + + The schema must be either a dictionary where the keys are voluptuous markers, or + a voluptuous.All schema where the first item is a voluptuous Mapping schema. + """ + if isinstance(schema, dict): + command = schema["type"] + else: + command = schema.validators[0].schema["type"] def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: """Decorate ws command function.""" # pylint: disable=protected-access - func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined] + if isinstance(schema, dict): + func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined] + else: + extended_schema = vol.All( + schema.validators[0].extend( + messages.BASE_COMMAND_MESSAGE_SCHEMA.schema + ), + *schema.validators[1:], + ) + func._ws_schema = extended_schema # type: ignore[attr-defined] func._ws_command = command # type: ignore[attr-defined] return func