mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 10:17:09 +00:00
Allow picking a pipeline for voip devices (#91524)
* Allow picking a pipeline for voip device * Add tests * Fix test * Adjust on new pipeline data
This commit is contained in:
parent
9bd12f6503
commit
bd22e0bd43
@ -52,16 +52,13 @@ async def async_pipeline_from_audio_stream(
|
|||||||
tts_options: dict | None = None,
|
tts_options: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an audio pipeline from an audio stream."""
|
"""Create an audio pipeline from an audio stream."""
|
||||||
if language is None:
|
if language is None and pipeline_id is None:
|
||||||
language = hass.config.language
|
language = hass.config.language
|
||||||
|
|
||||||
# Temporary workaround for language codes
|
# Temporary workaround for language codes
|
||||||
if language == "en":
|
if language == "en":
|
||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
if stt_metadata.language == "":
|
|
||||||
stt_metadata.language = language
|
|
||||||
|
|
||||||
if context is None:
|
if context is None:
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
@ -75,6 +72,9 @@ async def async_pipeline_from_audio_stream(
|
|||||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if stt_metadata.language == "":
|
||||||
|
stt_metadata.language = pipeline.language
|
||||||
|
|
||||||
pipeline_input = PipelineInput(
|
pipeline_input = PipelineInput(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
stt_metadata=stt_metadata,
|
stt_metadata=stt_metadata,
|
||||||
|
@ -105,7 +105,7 @@ class Pipeline:
|
|||||||
"""A voice assistant pipeline."""
|
"""A voice assistant pipeline."""
|
||||||
|
|
||||||
conversation_engine: str | None
|
conversation_engine: str | None
|
||||||
language: str | None
|
language: str
|
||||||
name: str
|
name: str
|
||||||
stt_engine: str | None
|
stt_engine: str | None
|
||||||
tts_engine: str | None
|
tts_engine: str | None
|
||||||
|
95
homeassistant/components/assist_pipeline/select.py
Normal file
95
homeassistant/components/assist_pipeline/select.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
"""Select entities for a pipeline."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from homeassistant.components.select import SelectEntity, SelectEntityDescription
|
||||||
|
from homeassistant.const import EntityCategory, Platform
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .pipeline import PipelineStorageCollection
|
||||||
|
|
||||||
|
OPTION_PREFERRED = "preferred"
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def get_chosen_pipeline(
|
||||||
|
hass: HomeAssistant, domain: str, unique_id_prefix: str
|
||||||
|
) -> str | None:
|
||||||
|
"""Get the chosen pipeline for a domain."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
pipeline_entity_id = ent_reg.async_get_entity_id(
|
||||||
|
Platform.SELECT, domain, f"{unique_id_prefix}-pipeline"
|
||||||
|
)
|
||||||
|
if pipeline_entity_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
state = hass.states.get(pipeline_entity_id)
|
||||||
|
if state is None or state.state == OPTION_PREFERRED:
|
||||||
|
return None
|
||||||
|
|
||||||
|
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN]
|
||||||
|
return next(
|
||||||
|
(item.id for item in pipeline_store.async_items() if item.name == state.state),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
||||||
|
"""Entity to represent a pipeline selector."""
|
||||||
|
|
||||||
|
entity_description = SelectEntityDescription(
|
||||||
|
key="pipeline",
|
||||||
|
translation_key="pipeline",
|
||||||
|
entity_category=EntityCategory.CONFIG,
|
||||||
|
)
|
||||||
|
_attr_should_poll = False
|
||||||
|
_attr_current_option = OPTION_PREFERRED
|
||||||
|
_attr_options = [OPTION_PREFERRED]
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None:
|
||||||
|
"""Initialize a pipeline selector."""
|
||||||
|
self._attr_unique_id = f"{unique_id_prefix}-pipeline"
|
||||||
|
self.hass = hass
|
||||||
|
self._update_options()
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""When entity is added to Home Assistant."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
|
pipeline_store: PipelineStorageCollection = self.hass.data[
|
||||||
|
DOMAIN
|
||||||
|
].pipeline_store
|
||||||
|
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||||
|
|
||||||
|
state = await self.async_get_last_state()
|
||||||
|
if state is not None and state.state in self.options:
|
||||||
|
self._attr_current_option = state.state
|
||||||
|
|
||||||
|
async def async_select_option(self, option: str) -> None:
|
||||||
|
"""Select an option."""
|
||||||
|
self._attr_current_option = option
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
|
async def _pipelines_updated(
|
||||||
|
self, change_sets: Iterable[collection.CollectionChangeSet]
|
||||||
|
) -> None:
|
||||||
|
"""Handle pipeline update."""
|
||||||
|
self._update_options()
|
||||||
|
self.async_write_ha_state()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _update_options(self) -> None:
|
||||||
|
"""Handle pipeline update."""
|
||||||
|
pipeline_store: PipelineStorageCollection = self.hass.data[
|
||||||
|
DOMAIN
|
||||||
|
].pipeline_store
|
||||||
|
options = [OPTION_PREFERRED]
|
||||||
|
options.extend(sorted(item.name for item in pipeline_store.async_items()))
|
||||||
|
self._attr_options = options
|
||||||
|
|
||||||
|
if self._attr_current_option not in options:
|
||||||
|
self._attr_current_option = OPTION_PREFERRED
|
12
homeassistant/components/assist_pipeline/strings.json
Normal file
12
homeassistant/components/assist_pipeline/strings.json
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"entity": {
|
||||||
|
"select": {
|
||||||
|
"pipeline": {
|
||||||
|
"name": "Assist Pipeline",
|
||||||
|
"state": {
|
||||||
|
"preferred": "Preferred"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -19,6 +19,7 @@ from .voip import HassVoipDatagramProtocol
|
|||||||
|
|
||||||
PLATFORMS = (
|
PLATFORMS = (
|
||||||
Platform.BINARY_SENSOR,
|
Platform.BINARY_SENSOR,
|
||||||
|
Platform.SELECT,
|
||||||
Platform.SWITCH,
|
Platform.SWITCH,
|
||||||
)
|
)
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
46
homeassistant/components/voip/select.py
Normal file
46
homeassistant/components/voip/select.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""Select entities for VoIP integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .devices import VoIPDevice
|
||||||
|
from .entity import VoIPEntity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import DomainData
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up VoIP switch entities."""
|
||||||
|
domain_data: DomainData = hass.data[DOMAIN]
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_add_device(device: VoIPDevice) -> None:
|
||||||
|
"""Add device."""
|
||||||
|
async_add_entities([VoipPipelineSelect(hass, device)])
|
||||||
|
|
||||||
|
domain_data.devices.async_add_new_device_listener(async_add_device)
|
||||||
|
|
||||||
|
async_add_entities(
|
||||||
|
[VoipPipelineSelect(hass, device) for device in domain_data.devices]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect):
|
||||||
|
"""Pipeline selector for VoIP devices."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None:
|
||||||
|
"""Initialize a pipeline selector."""
|
||||||
|
VoIPEntity.__init__(self, device)
|
||||||
|
AssistPipelineSelect.__init__(self, hass, device.voip_id)
|
@ -24,6 +24,14 @@
|
|||||||
"allow_call": {
|
"allow_call": {
|
||||||
"name": "Allow Calls"
|
"name": "Allow Calls"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"select": {
|
||||||
|
"pipeline": {
|
||||||
|
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
|
||||||
|
"state": {
|
||||||
|
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,11 +15,14 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
|
select as pipeline_select,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
|
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
|
||||||
from homeassistant.const import __version__
|
from homeassistant.const import __version__
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .devices import VoIPDevice, VoIPDevices
|
from .devices import VoIPDevice, VoIPDevices
|
||||||
|
|
||||||
@ -151,7 +154,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
),
|
),
|
||||||
stt_stream=stt_stream(),
|
stt_stream=stt_stream(),
|
||||||
language=self.language,
|
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||||
|
self.hass, DOMAIN, self.voip_device.voip_id
|
||||||
|
),
|
||||||
conversation_id=self._conversation_id,
|
conversation_id=self._conversation_id,
|
||||||
tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
||||||
)
|
)
|
||||||
|
@ -143,20 +143,24 @@ class ObservableCollection(ABC, Generic[_ItemT]):
|
|||||||
return list(self.data.values())
|
return list(self.data.values())
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_add_listener(self, listener: ChangeListener) -> None:
|
def async_add_listener(self, listener: ChangeListener) -> Callable[[], None]:
|
||||||
"""Add a listener.
|
"""Add a listener.
|
||||||
|
|
||||||
Will be called with (change_type, item_id, updated_config).
|
Will be called with (change_type, item_id, updated_config).
|
||||||
"""
|
"""
|
||||||
self.listeners.append(listener)
|
self.listeners.append(listener)
|
||||||
|
return lambda: self.listeners.remove(listener)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_add_change_set_listener(self, listener: ChangeSetListener) -> None:
|
def async_add_change_set_listener(
|
||||||
|
self, listener: ChangeSetListener
|
||||||
|
) -> Callable[[], None]:
|
||||||
"""Add a listener for a full change set.
|
"""Add a listener for a full change set.
|
||||||
|
|
||||||
Will be called with [(change_type, item_id, updated_config), ...]
|
Will be called with [(change_type, item_id, updated_config), ...]
|
||||||
"""
|
"""
|
||||||
self.change_set_listeners.append(listener)
|
self.change_set_listeners.append(listener)
|
||||||
|
return lambda: self.change_set_listeners.remove(listener)
|
||||||
|
|
||||||
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
|
async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None:
|
||||||
"""Notify listeners of a change."""
|
"""Notify listeners of a change."""
|
||||||
|
@ -6,6 +6,8 @@ from unittest.mock import AsyncMock, Mock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
from homeassistant.components import stt, tts
|
||||||
|
from homeassistant.components.assist_pipeline import DOMAIN
|
||||||
|
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -137,3 +139,9 @@ async def init_components(
|
|||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||||
assert await async_setup_component(hass, "media_source", {})
|
assert await async_setup_component(hass, "media_source", {})
|
||||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||||
|
"""Return pipeline storage collection."""
|
||||||
|
return hass.data[DOMAIN].pipeline_store
|
||||||
|
117
tests/components/assist_pipeline/test_select.py
Normal file
117
tests/components/assist_pipeline/test_select.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
"""Test select entity."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline import Pipeline
|
||||||
|
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||||
|
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform
|
||||||
|
|
||||||
|
|
||||||
|
class SelectPlatform(MockPlatform):
|
||||||
|
"""Fake select platform."""
|
||||||
|
|
||||||
|
# pylint: disable=method-hidden
|
||||||
|
async def async_setup_entry(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up fake select platform."""
|
||||||
|
async_add_entities([AssistPipelineSelect(hass, "test")])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry:
|
||||||
|
"""Initialize select entity."""
|
||||||
|
mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform())
|
||||||
|
config_entry = MockConfigEntry(domain="assist_pipeline")
|
||||||
|
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||||
|
return config_entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def pipeline_1(
|
||||||
|
hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection
|
||||||
|
) -> Pipeline:
|
||||||
|
"""Create a pipeline."""
|
||||||
|
return await pipeline_storage.async_create_item(
|
||||||
|
{
|
||||||
|
"name": "Test 1",
|
||||||
|
"language": "en-US",
|
||||||
|
"conversation_engine": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"stt_engine": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def pipeline_2(
|
||||||
|
hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection
|
||||||
|
) -> Pipeline:
|
||||||
|
"""Create a pipeline."""
|
||||||
|
return await pipeline_storage.async_create_item(
|
||||||
|
{
|
||||||
|
"name": "Test 2",
|
||||||
|
"language": "en-US",
|
||||||
|
"conversation_engine": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"stt_engine": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_select_entity_changing_pipelines(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_select: ConfigEntry,
|
||||||
|
pipeline_1: Pipeline,
|
||||||
|
pipeline_2: Pipeline,
|
||||||
|
pipeline_storage: PipelineStorageCollection,
|
||||||
|
) -> None:
|
||||||
|
"""Test entity tracking pipeline changes."""
|
||||||
|
config_entry = init_select # nicer naming
|
||||||
|
|
||||||
|
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == "preferred"
|
||||||
|
assert state.attributes["options"] == [
|
||||||
|
"preferred",
|
||||||
|
pipeline_1.name,
|
||||||
|
pipeline_2.name,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Change select to new pipeline
|
||||||
|
await hass.services.async_call(
|
||||||
|
"select",
|
||||||
|
"select_option",
|
||||||
|
{
|
||||||
|
"entity_id": "select.assist_pipeline_test_pipeline",
|
||||||
|
"option": pipeline_2.name,
|
||||||
|
},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||||
|
assert state.state == pipeline_2.name
|
||||||
|
|
||||||
|
# Reload config entry to test selected option persists
|
||||||
|
assert await hass.config_entries.async_forward_entry_unload(config_entry, "select")
|
||||||
|
assert await hass.config_entries.async_forward_entry_setup(config_entry, "select")
|
||||||
|
|
||||||
|
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||||
|
assert state.state == pipeline_2.name
|
||||||
|
|
||||||
|
# Remove selected pipeline
|
||||||
|
await pipeline_storage.async_delete_item(pipeline_2.id)
|
||||||
|
|
||||||
|
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||||
|
assert state.state == "preferred"
|
||||||
|
assert state.attributes["options"] == ["preferred", pipeline_1.name]
|
19
tests/components/voip/test_select.py
Normal file
19
tests/components/voip/test_select.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""Test VoIP select."""
|
||||||
|
from homeassistant.components.voip.devices import VoIPDevice
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_select(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
voip_device: VoIPDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline select.
|
||||||
|
|
||||||
|
Functionality is tested in assist_pipeline/test_select.py.
|
||||||
|
This test is only to ensure it is set up.
|
||||||
|
"""
|
||||||
|
state = hass.states.get("select.192_168_1_210_assist_pipeline")
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == "preferred"
|
Loading…
x
Reference in New Issue
Block a user