Assist pipeline to use configured values (#91901)

* Assist pipeline to use configured values

* Include voice in TTS-START event

* Use correct tts language var

* More vars

* Apply suggestions from code review

* Update

---------

Co-authored-by: Bram Kragten <mail@bramkragten.nl>
Authored by: Paulus Schoutsen on 2023-04-23 12:48:11 -04:00; committed by GitHub
parent ec1952b926
commit f4df0ca50a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 26 deletions

View File

@ -57,9 +57,6 @@ async def async_pipeline_from_audio_stream(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
)
if stt_metadata.language == "":
stt_metadata.language = pipeline.language
pipeline_input = PipelineInput(
conversation_id=conversation_id,
stt_metadata=stt_metadata,

View File

@ -282,12 +282,14 @@ class PipelineRun:
message=f"No speech to text provider for: {engine}",
)
metadata.language = self.pipeline.stt_language or self.language
if not stt_provider.check_metadata(metadata):
raise SpeechToTextError(
code="stt-provider-unsupported-metadata",
message=(
f"Provider {stt_provider.name} does not support input speech "
"to text metadata"
"to text metadata {metadata}"
),
)
@ -382,6 +384,7 @@ class PipelineRun:
PipelineEventType.INTENT_START,
{
"engine": self.intent_agent,
"language": self.pipeline.conversation_language,
"intent_input": intent_input,
},
)
@ -393,7 +396,7 @@ class PipelineRun:
text=intent_input,
conversation_id=conversation_id,
context=self.context,
language=self.language,
language=self.pipeline.conversation_language,
agent_id=self.intent_agent,
)
except Exception as src_error:
@ -439,14 +442,14 @@ class PipelineRun:
if not await tts.async_support_options(
self.hass,
engine,
self.language,
self.pipeline.tts_language,
tts_options,
):
raise TextToSpeechError(
code="tts-not-supported",
message=(
f"Text to speech engine {engine} "
f"does not support language {self.language} or options {tts_options}"
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
),
)
@ -463,6 +466,8 @@ class PipelineRun:
PipelineEventType.TTS_START,
{
"engine": self.tts_engine,
"language": self.pipeline.tts_language,
"voice": self.pipeline.tts_voice,
"tts_input": tts_input,
},
)
@ -474,7 +479,7 @@ class PipelineRun:
self.hass,
tts_input,
engine=self.tts_engine,
language=self.language,
language=self.pipeline.tts_language,
options=self.tts_options,
)
tts_media = await media_source.async_resolve_media(

View File

@ -137,7 +137,7 @@ async def websocket_run(
# Audio input must be raw PCM at 16Khz with 16-bit mono samples
input_args["stt_metadata"] = stt.SpeechMetadata(
language=pipeline.language,
language=pipeline.stt_language or pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,

View File

@ -34,6 +34,7 @@
'data': dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': None,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
@ -63,14 +64,16 @@
dict({
'data': dict({
'engine': 'test',
'language': None,
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@ -87,7 +90,7 @@
list([
dict({
'data': dict({
'language': 'en-US',
'language': 'en',
'pipeline': 'test_name',
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
@ -118,6 +121,7 @@
'data': dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en-US',
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
@ -147,14 +151,16 @@
dict({
'data': dict({
'engine': 'test',
'language': 'en-UA',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'Arnold Schwarzenegger',
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-UA&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),
@ -171,7 +177,7 @@
list([
dict({
'data': dict({
'language': 'en-US',
'language': 'en',
'pipeline': 'test_name',
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
@ -202,6 +208,7 @@
'data': dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': 'en-US',
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
@ -231,14 +238,16 @@
dict({
'data': dict({
'engine': 'test',
'language': 'en-AU',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'Arnold Schwarzenegger',
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-AU&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),

View File

@ -33,6 +33,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': None,
})
# ---
# name: test_audio_pipeline.4
@ -60,13 +61,15 @@
# name: test_audio_pipeline.5
dict({
'engine': 'test',
'language': None,
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
})
# ---
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@ -106,6 +109,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': None,
})
# ---
# name: test_audio_pipeline_debug.4
@ -133,13 +137,15 @@
# name: test_audio_pipeline_debug.5
dict({
'engine': 'test',
'language': None,
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
})
# ---
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
}),
@ -159,6 +165,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'Are the lights on?',
'language': None,
})
# ---
# name: test_intent_timeout
@ -175,6 +182,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'Are the lights on?',
'language': None,
})
# ---
# name: test_intent_timeout.2
@ -243,6 +251,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'Are the lights on?',
'language': None,
})
# ---
# name: test_text_only_pipeline.2
@ -286,6 +295,8 @@
# name: test_tts_failed.1
dict({
'engine': 'test',
'language': None,
'tts_input': 'Lights are on.',
'voice': None,
})
# ---

View File

@ -79,13 +79,13 @@ async def test_pipeline_from_audio_stream_legacy(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "test_language",
"language": "en-US",
"conversation_language": "en-US",
"language": "en",
"name": "test_name",
"stt_engine": "test",
"stt_language": "test_language",
"stt_language": "en-UK",
"tts_engine": "test",
"tts_language": "test_language",
"tts_language": "en-AU",
"tts_voice": "Arnold Schwarzenegger",
}
)
@ -99,7 +99,7 @@ async def test_pipeline_from_audio_stream_legacy(
Context(),
events.append,
stt.SpeechMetadata(
language="",
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
@ -145,13 +145,13 @@ async def test_pipeline_from_audio_stream_entity(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "homeassistant",
"conversation_language": "test_language",
"language": "en-US",
"conversation_language": "en-US",
"language": "en",
"name": "test_name",
"stt_engine": mock_stt_provider_entity.entity_id,
"stt_language": "test_language",
"stt_language": "en-UK",
"tts_engine": "test",
"tts_language": "test_language",
"tts_language": "en-UA",
"tts_voice": "Arnold Schwarzenegger",
}
)
@ -165,7 +165,7 @@ async def test_pipeline_from_audio_stream_entity(
Context(),
events.append,
stt.SpeechMetadata(
language="",
language="en-UK",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,