Use pipeline ID in event (#92100)

* Use pipeline ID in event

* Fix tests
This commit is contained in:
Paulus Schoutsen 2023-04-26 22:40:17 -04:00 committed by GitHub
parent 32ed45084a
commit 7c696754ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 40 additions and 38 deletions

View File

@ -370,7 +370,7 @@ class PipelineRun:
def start(self) -> None: def start(self) -> None:
"""Emit run start event.""" """Emit run start event."""
data = { data = {
"pipeline": self.pipeline.name, "pipeline": self.pipeline.id,
"language": self.language, "language": self.language,
} }
if self.runner_data is not None: if self.runner_data is not None:

View File

@ -1,5 +1,4 @@
"""Tests for the Voice Assistant integration.""" """Tests for the Voice Assistant integration."""
MANY_LANGUAGES = [ MANY_LANGUAGES = [
"ar", "ar",
"bg", "bg",

View File

@ -4,7 +4,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
}), }),
'type': <PipelineEventType.RUN_START: 'run-start'>, 'type': <PipelineEventType.RUN_START: 'run-start'>,
}), }),
@ -91,7 +91,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'language': 'en', 'language': 'en',
'pipeline': 'test_name', 'pipeline': <ANY>,
}), }),
'type': <PipelineEventType.RUN_START: 'run-start'>, 'type': <PipelineEventType.RUN_START: 'run-start'>,
}), }),
@ -178,7 +178,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'language': 'en', 'language': 'en',
'pipeline': 'test_name', 'pipeline': <ANY>,
}), }),
'type': <PipelineEventType.RUN_START: 'run-start'>, 'type': <PipelineEventType.RUN_START: 'run-start'>,
}), }),

View File

@ -2,7 +2,7 @@
# name: test_audio_pipeline # name: test_audio_pipeline
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -78,7 +78,7 @@
# name: test_audio_pipeline_debug # name: test_audio_pipeline_debug
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -154,7 +154,7 @@
# name: test_intent_failed # name: test_intent_failed
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -171,7 +171,7 @@
# name: test_intent_timeout # name: test_intent_timeout
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 0.1, 'timeout': 0.1,
@ -217,7 +217,7 @@
# name: test_stt_stream_failed # name: test_stt_stream_failed
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -240,7 +240,7 @@
# name: test_text_only_pipeline # name: test_text_only_pipeline
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -285,7 +285,7 @@
# name: test_tts_failed # name: test_tts_failed
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'Home Assistant', 'pipeline': <ANY>,
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,

View File

@ -1,5 +1,6 @@
"""Test Voice Assistant init.""" """Test Voice Assistant init."""
from dataclasses import asdict from dataclasses import asdict
from unittest.mock import ANY
import pytest import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -12,6 +13,19 @@ from .conftest import MockSttProvider, MockSttProviderEntity
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
"""Process events to remove dynamic values."""
processed = []
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
as_dict["data"]["pipeline"] = ANY
processed.append(as_dict)
return processed
async def test_pipeline_from_audio_stream_auto( async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSttProvider,
@ -45,13 +59,7 @@ async def test_pipeline_from_audio_stream_auto(
audio_data(), audio_data(),
) )
processed = [] assert process_events(events) == snapshot
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
processed.append(as_dict)
assert processed == snapshot
assert mock_stt_provider.received == [b"part1", b"part2"] assert mock_stt_provider.received == [b"part1", b"part2"]
@ -111,13 +119,7 @@ async def test_pipeline_from_audio_stream_legacy(
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
) )
processed = [] assert process_events(events) == snapshot
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
processed.append(as_dict)
assert processed == snapshot
assert mock_stt_provider.received == [b"part1", b"part2"] assert mock_stt_provider.received == [b"part1", b"part2"]
@ -177,13 +179,7 @@ async def test_pipeline_from_audio_stream_entity(
pipeline_id=pipeline_id, pipeline_id=pipeline_id,
) )
processed = [] assert process_events(events) == snapshot
for event in events:
as_dict = asdict(event)
as_dict.pop("timestamp")
processed.append(as_dict)
assert processed == snapshot
assert mock_stt_provider_entity.received == [b"part1", b"part2"] assert mock_stt_provider_entity.received == [b"part1", b"part2"]

View File

@ -1,6 +1,6 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
import asyncio import asyncio
from unittest.mock import ANY, MagicMock, patch from unittest.mock import ANY, patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
@ -37,6 +37,7 @@ async def test_text_only_pipeline(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
@ -101,6 +102,7 @@ async def test_audio_pipeline(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
@ -196,6 +198,7 @@ async def test_intent_timeout(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
@ -292,7 +295,7 @@ async def test_intent_failed(
with patch( with patch(
"homeassistant.components.conversation.async_converse", "homeassistant.components.conversation.async_converse",
new=MagicMock(return_value=RuntimeError), side_effect=RuntimeError,
): ):
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -310,6 +313,7 @@ async def test_intent_failed(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
@ -405,7 +409,7 @@ async def test_stt_provider_missing(
"""Test events from a pipeline run with a non-existent STT provider.""" """Test events from a pipeline run with a non-existent STT provider."""
with patch( with patch(
"homeassistant.components.stt.async_get_provider", "homeassistant.components.stt.async_get_provider",
new=MagicMock(return_value=None), return_value=None,
): ):
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -438,7 +442,7 @@ async def test_stt_stream_failed(
with patch( with patch(
"tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream", "tests.components.assist_pipeline.conftest.MockSttProvider.async_process_audio_stream",
new=MagicMock(side_effect=RuntimeError), side_effect=RuntimeError,
): ):
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -458,6 +462,7 @@ async def test_stt_stream_failed(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
@ -504,7 +509,7 @@ async def test_tts_failed(
with patch( with patch(
"homeassistant.components.media_source.async_resolve_media", "homeassistant.components.media_source.async_resolve_media",
new=MagicMock(return_value=RuntimeError), side_effect=RuntimeError,
): ):
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -522,6 +527,7 @@ async def test_tts_failed(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
@ -1105,6 +1111,7 @@ async def test_audio_pipeline_debug(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])