Clean up voice assistant integration (#90239)

* Clean up voice assistant

* Reinstate auto-removed imports

* Resample STT audio from 44.1Khz to 16Khz

* Energy based VAD for prototyping

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2023-03-26 22:41:17 -04:00 committed by GitHub
parent 7098debe09
commit c3717f8182
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 407 additions and 237 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
import logging
from hass_nabucasa import Cloud from hass_nabucasa import Cloud
from hass_nabucasa.voice import VoiceError from hass_nabucasa.voice import VoiceError
@ -20,6 +21,8 @@ from homeassistant.components.stt import (
from .const import DOMAIN from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
SUPPORT_LANGUAGES = [ SUPPORT_LANGUAGES = [
"da-DK", "da-DK",
"de-DE", "de-DE",
@ -102,7 +105,8 @@ class CloudProvider(Provider):
result = await self.cloud.voice.process_stt( result = await self.cloud.voice.process_stt(
stream, content, metadata.language stream, content, metadata.language
) )
except VoiceError: except VoiceError as err:
_LOGGER.debug("Voice error: %s", err)
return SpeechResult(None, SpeechResultState.ERROR) return SpeechResult(None, SpeechResultState.ERROR)
# Return Speech as Text # Return Speech as Text

View File

@ -150,6 +150,7 @@ class PipelineRun:
end_stage: PipelineStage end_stage: PipelineStage
event_callback: Callable[[PipelineEvent], None] event_callback: Callable[[PipelineEvent], None]
language: str = None # type: ignore[assignment] language: str = None # type: ignore[assignment]
runner_data: Any | None = None
def __post_init__(self): def __post_init__(self):
"""Set language for pipeline.""" """Set language for pipeline."""
@ -163,15 +164,14 @@ class PipelineRun:
def start(self): def start(self):
"""Emit run start event.""" """Emit run start event."""
self.event_callback( data = {
PipelineEvent( "pipeline": self.pipeline.name,
PipelineEventType.RUN_START, "language": self.language,
{ }
"pipeline": self.pipeline.name, if self.runner_data is not None:
"language": self.language, data["runner_data"] = self.runner_data
},
) self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
)
def end(self): def end(self):
"""Emit run end event.""" """Emit run end event."""
@ -200,41 +200,45 @@ class PipelineRun:
try: try:
# Load provider # Load provider
stt_provider = stt.async_get_provider(self.hass, self.pipeline.stt_engine) stt_provider: stt.Provider = stt.async_get_provider(
self.hass, self.pipeline.stt_engine
)
assert stt_provider is not None assert stt_provider is not None
except Exception as src_error: except Exception as src_error:
stt_error = SpeechToTextError( _LOGGER.exception("No speech to text provider for %s", engine)
raise SpeechToTextError(
code="stt-provider-missing", code="stt-provider-missing",
message=f"No speech to text provider for: {engine}", message=f"No speech to text provider for: {engine}",
) from src_error
if not stt_provider.check_metadata(metadata):
raise SpeechToTextError(
code="stt-provider-unsupported-metadata",
message=f"Provider {engine} does not support input speech to text metadata",
) )
_LOGGER.exception(stt_error.message)
self.event_callback(
PipelineEvent(
PipelineEventType.ERROR,
{"code": stt_error.code, "message": stt_error.message},
)
)
raise stt_error from src_error
try: try:
# Transcribe audio stream # Transcribe audio stream
result = await stt_provider.async_process_audio_stream(metadata, stream) result = await stt_provider.async_process_audio_stream(metadata, stream)
assert (result.text is not None) and (
result.result == stt.SpeechResultState.SUCCESS
)
except Exception as src_error: except Exception as src_error:
stt_error = SpeechToTextError( _LOGGER.exception("Unexpected error during speech to text")
raise SpeechToTextError(
code="stt-stream-failed", code="stt-stream-failed",
message="Unexpected error during speech to text", message="Unexpected error during speech to text",
) from src_error
_LOGGER.debug("speech-to-text result %s", result)
if result.result != stt.SpeechResultState.SUCCESS:
raise SpeechToTextError(
code="stt-stream-failed",
message="Speech to text failed",
) )
_LOGGER.exception(stt_error.message)
self.event_callback( if not result.text:
PipelineEvent( raise SpeechToTextError(
PipelineEventType.ERROR, code="stt-no-text-recognized", message="No text recognized"
{"code": stt_error.code, "message": stt_error.message},
)
) )
raise stt_error from src_error
self.event_callback( self.event_callback(
PipelineEvent( PipelineEvent(
@ -273,18 +277,13 @@ class PipelineRun:
agent_id=self.pipeline.conversation_engine, agent_id=self.pipeline.conversation_engine,
) )
except Exception as src_error: except Exception as src_error:
intent_error = IntentRecognitionError( _LOGGER.exception("Unexpected error during intent recognition")
raise IntentRecognitionError(
code="intent-failed", code="intent-failed",
message="Unexpected error during intent recognition", message="Unexpected error during intent recognition",
) ) from src_error
_LOGGER.exception(intent_error.message)
self.event_callback( _LOGGER.debug("conversation result %s", conversation_result)
PipelineEvent(
PipelineEventType.ERROR,
{"code": intent_error.code, "message": intent_error.message},
)
)
raise intent_error from src_error
self.event_callback( self.event_callback(
PipelineEvent( PipelineEvent(
@ -320,18 +319,13 @@ class PipelineRun:
), ),
) )
except Exception as src_error: except Exception as src_error:
tts_error = TextToSpeechError( _LOGGER.exception("Unexpected error during text to speech")
raise TextToSpeechError(
code="tts-failed", code="tts-failed",
message="Unexpected error during text to speech", message="Unexpected error during text to speech",
) ) from src_error
_LOGGER.exception(tts_error.message)
self.event_callback( _LOGGER.debug("TTS result %s", tts_media)
PipelineEvent(
PipelineEventType.ERROR,
{"code": tts_error.code, "message": tts_error.message},
)
)
raise tts_error from src_error
self.event_callback( self.event_callback(
PipelineEvent( PipelineEvent(
@ -377,31 +371,41 @@ class PipelineInput:
run.start() run.start()
current_stage = run.start_stage current_stage = run.start_stage
# Speech to text try:
intent_input = self.intent_input # Speech to text
if current_stage == PipelineStage.STT: intent_input = self.intent_input
assert self.stt_metadata is not None if current_stage == PipelineStage.STT:
assert self.stt_stream is not None assert self.stt_metadata is not None
intent_input = await run.speech_to_text( assert self.stt_stream is not None
self.stt_metadata, intent_input = await run.speech_to_text(
self.stt_stream, self.stt_metadata,
) self.stt_stream,
current_stage = PipelineStage.INTENT
if run.end_stage != PipelineStage.STT:
tts_input = self.tts_input
if current_stage == PipelineStage.INTENT:
assert intent_input is not None
tts_input = await run.recognize_intent(
intent_input, self.conversation_id
) )
current_stage = PipelineStage.TTS current_stage = PipelineStage.INTENT
if run.end_stage != PipelineStage.INTENT: if run.end_stage != PipelineStage.STT:
if current_stage == PipelineStage.TTS: tts_input = self.tts_input
assert tts_input is not None
await run.text_to_speech(tts_input) if current_stage == PipelineStage.INTENT:
assert intent_input is not None
tts_input = await run.recognize_intent(
intent_input, self.conversation_id
)
current_stage = PipelineStage.TTS
if run.end_stage != PipelineStage.INTENT:
if current_stage == PipelineStage.TTS:
assert tts_input is not None
await run.text_to_speech(tts_input)
except PipelineError as err:
run.event_callback(
PipelineEvent(
PipelineEventType.ERROR,
{"code": err.code, "message": err.message},
)
)
return
run.end() run.end()

View File

@ -1,5 +1,6 @@
"""Voice Assistant Websocket API.""" """Voice Assistant Websocket API."""
import asyncio import asyncio
import audioop
from collections.abc import Callable from collections.abc import Callable
import logging import logging
from typing import Any from typing import Any
@ -12,6 +13,8 @@ from homeassistant.core import HomeAssistant, callback
from .pipeline import ( from .pipeline import (
DEFAULT_TIMEOUT, DEFAULT_TIMEOUT,
PipelineError, PipelineError,
PipelineEvent,
PipelineEventType,
PipelineInput, PipelineInput,
PipelineRun, PipelineRun,
PipelineStage, PipelineStage,
@ -20,6 +23,10 @@ from .pipeline import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_VAD_ENERGY_THRESHOLD = 1000
_VAD_SPEECH_FRAMES = 25
_VAD_SILENCE_FRAMES = 25
@callback @callback
def async_register_websocket_api(hass: HomeAssistant) -> None: def async_register_websocket_api(hass: HomeAssistant) -> None:
@ -27,6 +34,17 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, websocket_run) websocket_api.async_register_command(hass, websocket_run)
def _get_debiased_energy(audio_data: bytes, width: int = 2) -> float:
"""Compute RMS of debiased audio."""
energy = -audioop.rms(audio_data, width)
energy_bytes = bytes([energy & 0xFF, (energy >> 8) & 0xFF])
debiased_energy = audioop.rms(
audioop.add(audio_data, energy_bytes * (len(audio_data) // width), width), width
)
return debiased_energy
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "voice_assistant/run", vol.Required("type"): "voice_assistant/run",
@ -49,6 +67,11 @@ async def websocket_run(
) -> None: ) -> None:
"""Run a pipeline.""" """Run a pipeline."""
language = msg.get("language", hass.config.language) language = msg.get("language", hass.config.language)
# Temporary workaround for language codes
if language == "en":
language = "en-US"
pipeline_id = msg.get("pipeline") pipeline_id = msg.get("pipeline")
pipeline = async_get_pipeline( pipeline = async_get_pipeline(
hass, hass,
@ -79,8 +102,32 @@ async def websocket_run(
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue() audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
async def stt_stream(): async def stt_stream():
state = None
speech_count = 0
in_voice_command = False
# Yield until we receive an empty chunk # Yield until we receive an empty chunk
while chunk := await audio_queue.get(): while chunk := await audio_queue.get():
chunk, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
is_speech = _get_debiased_energy(chunk) > _VAD_ENERGY_THRESHOLD
if in_voice_command:
if is_speech:
speech_count += 1
else:
speech_count -= 1
if speech_count <= -_VAD_SILENCE_FRAMES:
_LOGGER.info("Voice command stopped")
break
else:
if is_speech:
speech_count += 1
if speech_count >= _VAD_SPEECH_FRAMES:
in_voice_command = True
_LOGGER.info("Voice command started")
yield chunk yield chunk
def handle_binary(_hass, _connection, data: bytes): def handle_binary(_hass, _connection, data: bytes):
@ -119,6 +166,9 @@ async def websocket_run(
event_callback=lambda event: connection.send_event( event_callback=lambda event: connection.send_event(
msg["id"], event.as_dict() msg["id"], event.as_dict()
), ),
runner_data={
"stt_binary_handler_id": handler_id,
},
), ),
timeout=timeout, timeout=timeout,
) )
@ -130,16 +180,20 @@ async def websocket_run(
# Confirm subscription # Confirm subscription
connection.send_result(msg["id"]) connection.send_result(msg["id"])
if handler_id is not None:
# Send handler id to client
connection.send_event(msg["id"], {"handler_id": handler_id})
try: try:
# Task contains a timeout # Task contains a timeout
await run_task await run_task
except PipelineError as error: except PipelineError as error:
# Report more specific error when possible # Report more specific error when possible
connection.send_error(msg["id"], error.code, error.message) connection.send_error(msg["id"], error.code, error.message)
except asyncio.TimeoutError:
connection.send_event(
msg["id"],
PipelineEvent(
PipelineEventType.ERROR,
{"code": "timeout", "message": "Timeout running pipeline"},
),
)
finally: finally:
if unregister_handler is not None: if unregister_handler is not None:
# Unregister binary handler # Unregister binary handler

View File

@ -0,0 +1,210 @@
# serializer version: 1
# name: test_audio_pipeline
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': 1,
}),
})
# ---
# name: test_audio_pipeline.1
dict({
'engine': 'default',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_audio_pipeline.2
dict({
'stt_output': dict({
'text': 'test transcript',
}),
})
# ---
# name: test_audio_pipeline.3
dict({
'engine': 'default',
'intent_input': 'test transcript',
})
# ---
# name: test_audio_pipeline.4
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_audio_pipeline.5
dict({
'engine': 'default',
'tts_input': "Sorry, I couldn't understand that",
})
# ---
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en_-_demo.mp3',
}),
})
# ---
# name: test_intent_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': None,
}),
})
# ---
# name: test_intent_failed.1
dict({
'engine': 'default',
'intent_input': 'Are the lights on?',
})
# ---
# name: test_intent_timeout
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': None,
}),
})
# ---
# name: test_intent_timeout.1
dict({
'engine': 'default',
'intent_input': 'Are the lights on?',
})
# ---
# name: test_intent_timeout.2
dict({
'code': 'timeout',
'message': 'Timeout running pipeline',
})
# ---
# name: test_stt_provider_missing
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': 1,
}),
})
# ---
# name: test_stt_provider_missing.1
dict({
'engine': 'default',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_stt_stream_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': 1,
}),
})
# ---
# name: test_stt_stream_failed.1
dict({
'engine': 'default',
'metadata': dict({
'bit_rate': 16,
'channel': 1,
'codec': 'pcm',
'format': 'wav',
'language': 'en-US',
'sample_rate': 16000,
}),
})
# ---
# name: test_text_only_pipeline
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': None,
}),
})
# ---
# name: test_text_only_pipeline.1
dict({
'engine': 'default',
'intent_input': 'Are the lights on?',
})
# ---
# name: test_text_only_pipeline.2
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_intent_match',
}),
'language': 'en-US',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': "Sorry, I couldn't understand that",
}),
}),
}),
}),
})
# ---
# name: test_text_pipeline_timeout
dict({
'code': 'timeout',
'message': 'Timeout running pipeline',
})
# ---
# name: test_tts_failed
dict({
'language': 'en-US',
'pipeline': 'en-US',
'runner_data': dict({
'stt_binary_handler_id': None,
}),
})
# ---
# name: test_tts_failed.1
dict({
'engine': 'default',
'tts_input': 'Lights are on.',
})
# ---

View File

@ -4,6 +4,7 @@ from collections.abc import AsyncIterable
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import stt from homeassistant.components import stt
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -29,7 +30,7 @@ class MockSttProvider(stt.Provider):
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
return [self.hass.config.language] return ["en-US"]
@property @property
def supported_formats(self) -> list[stt.AudioFormats]: def supported_formats(self) -> list[stt.AudioFormats]:
@ -64,7 +65,11 @@ class MockSttProvider(stt.Provider):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
async def init_components(hass): async def init_components(
hass: HomeAssistant,
mock_get_cache_files, # noqa: F811
mock_init_cache_dir, # noqa: F811
):
"""Initialize relevant components with empty configs.""" """Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "media_source", {}) assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component( assert await async_setup_component(
@ -93,6 +98,7 @@ async def init_components(hass):
async def test_text_only_pipeline( async def test_text_only_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with text input (no STT/TTS).""" """Test events from a pipeline run with text input (no STT/TTS)."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -114,38 +120,16 @@ async def test_text_only_pipeline(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# intent # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"intent_input": "Are the lights on?",
}
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end" assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"intent_output": {
"response": {
"speech": {
"plain": {
"speech": "Sorry, I couldn't understand that",
"extra_data": None,
}
},
"card": {},
"language": "en",
"response_type": "error",
"data": {"code": "no_intent_match"},
},
"conversation_id": None,
}
}
# run end # run end
msg = await client.receive_json() msg = await client.receive_json()
@ -154,8 +138,7 @@ async def test_text_only_pipeline(
async def test_audio_pipeline( async def test_audio_pipeline(
hass: HomeAssistant, hass: HomeAssistant, hass_ws_client: WebSocketGenerator, snapshot: SnapshotAssertion
hass_ws_client: WebSocketGenerator,
) -> None: ) -> None:
"""Test events from a pipeline run with audio input/output.""" """Test events from a pipeline run with audio input/output."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -173,86 +156,40 @@ async def test_audio_pipeline(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# handler id
msg = await client.receive_json()
assert msg["event"]["handler_id"] == 1
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# stt # stt
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"metadata": {
"bit_rate": 16,
"channel": 1,
"codec": "pcm",
"format": "wav",
"language": "en",
"sample_rate": 16000,
},
}
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(b"1") await client.send_bytes(b"1")
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"stt_output": {"text": _TRANSCRIPT},
}
# intent # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"intent_input": _TRANSCRIPT,
}
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end" assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"intent_output": {
"response": {
"speech": {
"plain": {
"speech": "Sorry, I couldn't understand that",
"extra_data": None,
}
},
"card": {},
"language": "en",
"response_type": "error",
"data": {"code": "no_intent_match"},
},
"conversation_id": None,
}
}
# text to speech # text to speech
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start" assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"tts_input": "Sorry, I couldn't understand that",
}
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end" assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"tts_output": {
"url": f"/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_{hass.config.language}_-_demo.mp3",
"mime_type": "audio/mpeg",
},
}
# run end # run end
msg = await client.receive_json() msg = await client.receive_json()
@ -261,7 +198,10 @@ async def test_audio_pipeline(
async def test_intent_timeout( async def test_intent_timeout(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test partial pipeline run with conversation agent timeout.""" """Test partial pipeline run with conversation agent timeout."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -291,27 +231,24 @@ async def test_intent_timeout(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# intent # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"intent_input": "Are the lights on?",
}
# timeout error # timeout error
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert msg["event"]["type"] == "error"
assert msg["error"]["code"] == "timeout" assert msg["event"]["data"] == snapshot
async def test_text_pipeline_timeout( async def test_text_pipeline_timeout(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test text-only pipeline run with immediate timeout.""" """Test text-only pipeline run with immediate timeout."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -340,12 +277,15 @@ async def test_text_pipeline_timeout(
# timeout error # timeout error
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert msg["event"]["type"] == "error"
assert msg["error"]["code"] == "timeout" assert msg["event"]["data"] == snapshot
async def test_intent_failed( async def test_intent_failed(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test text-only pipeline run with conversation agent error.""" """Test text-only pipeline run with conversation agent error."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -371,18 +311,12 @@ async def test_intent_failed(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# intent start # intent start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"intent_input": "Are the lights on?",
}
# intent error # intent error
msg = await client.receive_json() msg = await client.receive_json()
@ -391,7 +325,10 @@ async def test_intent_failed(
async def test_audio_pipeline_timeout( async def test_audio_pipeline_timeout(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test audio pipeline run with immediate timeout.""" """Test audio pipeline run with immediate timeout."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -417,19 +354,16 @@ async def test_audio_pipeline_timeout(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# handler id
msg = await client.receive_json()
assert msg["event"]["handler_id"] == 1
# timeout error # timeout error
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert msg["event"]["type"] == "error"
assert msg["error"]["code"] == "timeout" assert msg["event"]["data"]["code"] == "timeout"
async def test_stt_provider_missing( async def test_stt_provider_missing(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with a non-existent STT provider.""" """Test events from a pipeline run with a non-existent STT provider."""
with patch( with patch(
@ -451,32 +385,15 @@ async def test_stt_provider_missing(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# handler id
msg = await client.receive_json()
assert msg["event"]["handler_id"] == 1
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# stt # stt
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"metadata": {
"bit_rate": 16,
"channel": 1,
"codec": "pcm",
"format": "wav",
"language": "en",
"sample_rate": 16000,
},
}
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(b"1") await client.send_bytes(b"1")
@ -490,6 +407,7 @@ async def test_stt_provider_missing(
async def test_stt_stream_failed( async def test_stt_stream_failed(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test events from a pipeline run with a non-existent STT provider.""" """Test events from a pipeline run with a non-existent STT provider."""
with patch( with patch(
@ -511,32 +429,15 @@ async def test_stt_stream_failed(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# handler id
msg = await client.receive_json()
assert msg["event"]["handler_id"] == 1
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# stt # stt
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"metadata": {
"bit_rate": 16,
"channel": 1,
"codec": "pcm",
"format": "wav",
"language": "en",
"sample_rate": 16000,
},
}
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(b"1") await client.send_bytes(b"1")
@ -548,7 +449,10 @@ async def test_stt_stream_failed(
async def test_tts_failed( async def test_tts_failed(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
init_components,
snapshot: SnapshotAssertion,
) -> None: ) -> None:
"""Test pipeline run with text to speech error.""" """Test pipeline run with text to speech error."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -574,18 +478,12 @@ async def test_tts_failed(
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"pipeline": hass.config.language,
"language": hass.config.language,
}
# tts start # tts start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start" assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == { assert msg["event"]["data"] == snapshot
"engine": "default",
"tts_input": "Lights are on.",
}
# tts error # tts error
msg = await client.receive_json() msg = await client.receive_json()