Add WS API for debugging previous assist_pipeline runs (#91541)

* Add WS API for debugging previous assist_pipeline runs

* Improve typing
This commit is contained in:
Erik Montnemery
2023-04-17 17:48:02 +02:00
committed by GitHub
parent b597415b01
commit 0ecd23baee
7 changed files with 564 additions and 32 deletions

View File

@@ -12,7 +12,9 @@ from homeassistant.components import stt, websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from .const import DOMAIN
from .pipeline import (
PipelineData,
PipelineError,
PipelineEvent,
PipelineEventType,
@@ -69,6 +71,8 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
),
),
)
websocket_api.async_register_command(hass, websocket_list_runs)
websocket_api.async_register_command(hass, websocket_get_run)
@websocket_api.async_response
@@ -193,14 +197,82 @@ async def websocket_run(
async with async_timeout.timeout(timeout):
await run_task
except asyncio.TimeoutError:
connection.send_event(
msg["id"],
pipeline_input.run.process_event(
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
),
)
)
finally:
if unregister_handler is not None:
# Unregister binary handler
unregister_handler()
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/pipeline_debug/list",
vol.Required("pipeline_id"): str,
}
)
def websocket_list_runs(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""List pipeline runs for which debug data is available."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = msg["pipeline_id"]
if pipeline_id not in pipeline_data.pipeline_runs:
connection.send_result(msg["id"], {"pipeline_runs": []})
return
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
connection.send_result(msg["id"], {"pipeline_runs": list(pipeline_runs)})
@callback
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_pipeline/pipeline_debug/get",
vol.Required("pipeline_id"): str,
vol.Required("pipeline_run_id"): str,
}
)
def websocket_get_run(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Get debug data for a pipeline run."""
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = msg["pipeline_id"]
pipeline_run_id = msg["pipeline_run_id"]
if pipeline_id not in pipeline_data.pipeline_runs:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"pipeline_id {pipeline_id} not found",
)
return
pipeline_runs = pipeline_data.pipeline_runs[pipeline_id]
if pipeline_run_id not in pipeline_runs:
connection.send_error(
msg["id"],
websocket_api.const.ERR_NOT_FOUND,
f"pipeline_run_id {pipeline_run_id} not found",
)
return
connection.send_result(
msg["id"],
{"events": pipeline_runs[pipeline_run_id]},
)