diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 5467f9d2f29..1c28e023381 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -52,16 +52,13 @@ async def async_pipeline_from_audio_stream( tts_options: dict | None = None, ) -> None: """Create an audio pipeline from an audio stream.""" - if language is None: + if language is None and pipeline_id is None: language = hass.config.language # Temporary workaround for language codes if language == "en": language = "en-US" - if stt_metadata.language == "": - stt_metadata.language = language - if context is None: context = Context() @@ -75,6 +72,9 @@ async def async_pipeline_from_audio_stream( "pipeline_not_found", f"Pipeline {pipeline_id} not found" ) + if stt_metadata.language == "": + stt_metadata.language = pipeline.language + pipeline_input = PipelineInput( conversation_id=conversation_id, stt_metadata=stt_metadata, diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 9f5f28ec727..e41b358ce50 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -105,7 +105,7 @@ class Pipeline: """A voice assistant pipeline.""" conversation_engine: str | None - language: str | None + language: str name: str stt_engine: str | None tts_engine: str | None diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py new file mode 100644 index 00000000000..275b46f7545 --- /dev/null +++ b/homeassistant/components/assist_pipeline/select.py @@ -0,0 +1,95 @@ +"""Select entities for a pipeline.""" + +from __future__ import annotations + +from collections.abc import Iterable + +from homeassistant.components.select import SelectEntity, SelectEntityDescription +from homeassistant.const import EntityCategory, Platform +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import collection, entity_registry as er, restore_state + +from .const import DOMAIN +from .pipeline import PipelineStorageCollection + +OPTION_PREFERRED = "preferred" + + +@callback +def get_chosen_pipeline( + hass: HomeAssistant, domain: str, unique_id_prefix: str +) -> str | None: + """Get the chosen pipeline for a domain.""" + ent_reg = er.async_get(hass) + pipeline_entity_id = ent_reg.async_get_entity_id( + Platform.SELECT, domain, f"{unique_id_prefix}-pipeline" + ) + if pipeline_entity_id is None: + return None + + state = hass.states.get(pipeline_entity_id) + if state is None or state.state == OPTION_PREFERRED: + return None + + pipeline_store: PipelineStorageCollection = hass.data[DOMAIN] + return next( + (item.id for item in pipeline_store.async_items() if item.name == state.state), + None, + ) + + +class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): + """Entity to represent a pipeline selector.""" + + entity_description = SelectEntityDescription( + key="pipeline", + translation_key="pipeline", + entity_category=EntityCategory.CONFIG, + ) + _attr_should_poll = False + _attr_current_option = OPTION_PREFERRED + _attr_options = [OPTION_PREFERRED] + + def __init__(self, hass: HomeAssistant, unique_id_prefix: str) -> None: + """Initialize a pipeline selector.""" + self._attr_unique_id = f"{unique_id_prefix}-pipeline" + self.hass = hass + self._update_options() + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + + pipeline_store: PipelineStorageCollection = self.hass.data[ + DOMAIN + ].pipeline_store + pipeline_store.async_add_change_set_listener(self._pipelines_updated) + + state = await self.async_get_last_state() + if state is not None and state.state in self.options: + self._attr_current_option = state.state + + async def async_select_option(self, option: str) -> None: + """Select an option.""" + self._attr_current_option = option + self.async_write_ha_state() + + async def _pipelines_updated( + self, change_sets: Iterable[collection.CollectionChangeSet] + ) -> None: + """Handle pipeline update.""" + self._update_options() + self.async_write_ha_state() + + @callback + def _update_options(self) -> None: + """Handle pipeline update.""" + pipeline_store: PipelineStorageCollection = self.hass.data[ + DOMAIN + ].pipeline_store + options = [OPTION_PREFERRED] + options.extend(sorted(item.name for item in pipeline_store.async_items())) + self._attr_options = options + + if self._attr_current_option not in options: + self._attr_current_option = OPTION_PREFERRED diff --git a/homeassistant/components/assist_pipeline/strings.json b/homeassistant/components/assist_pipeline/strings.json new file mode 100644 index 00000000000..8ee0ad286b9 --- /dev/null +++ b/homeassistant/components/assist_pipeline/strings.json @@ -0,0 +1,12 @@ +{ + "entity": { + "select": { + "pipeline": { + "name": "Assist Pipeline", + "state": { + "preferred": "Preferred" + } + } + } + } +} diff --git a/homeassistant/components/voip/__init__.py b/homeassistant/components/voip/__init__.py index 9328555505d..07cdccdca56 100644 --- a/homeassistant/components/voip/__init__.py +++ b/homeassistant/components/voip/__init__.py @@ -19,6 +19,7 @@ from .voip import HassVoipDatagramProtocol PLATFORMS = ( Platform.BINARY_SENSOR, + Platform.SELECT, Platform.SWITCH, ) _LOGGER = logging.getLogger(__name__) diff --git a/homeassistant/components/voip/select.py b/homeassistant/components/voip/select.py new file mode 100644 index 00000000000..7383e1b886a --- /dev/null +++ b/homeassistant/components/voip/select.py @@ -0,0 +1,46 @@ +"""Select entities for VoIP integration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from homeassistant.components.assist_pipeline.select import AssistPipelineSelect +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .devices import VoIPDevice +from .entity import VoIPEntity + +if TYPE_CHECKING: + from . import DomainData + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up VoIP switch entities.""" + domain_data: DomainData = hass.data[DOMAIN] + + @callback + def async_add_device(device: VoIPDevice) -> None: + """Add device.""" + async_add_entities([VoipPipelineSelect(hass, device)]) + + domain_data.devices.async_add_new_device_listener(async_add_device) + + async_add_entities( + [VoipPipelineSelect(hass, device) for device in domain_data.devices] + ) + + +class VoipPipelineSelect(VoIPEntity, AssistPipelineSelect): + """Pipeline selector for VoIP devices.""" + + def __init__(self, hass: HomeAssistant, device: VoIPDevice) -> None: + """Initialize a pipeline selector.""" + VoIPEntity.__init__(self, device) + AssistPipelineSelect.__init__(self, hass, device.voip_id) diff --git a/homeassistant/components/voip/strings.json b/homeassistant/components/voip/strings.json index dc3dd8a43cc..b483be494f2 100644 --- a/homeassistant/components/voip/strings.json +++ b/homeassistant/components/voip/strings.json @@ -24,6 +24,14 @@ "allow_call": { "name": "Allow Calls" } + }, + "select": { + "pipeline": { + "name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]", + "state": { + "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" + } + } } } } diff --git a/homeassistant/components/voip/voip.py b/homeassistant/components/voip/voip.py index 073ff690b7b..6bbdd38f2c2 100644 --- a/homeassistant/components/voip/voip.py +++ b/homeassistant/components/voip/voip.py @@ -15,11 +15,14 @@ from homeassistant.components.assist_pipeline import ( PipelineEvent, PipelineEventType, async_pipeline_from_audio_stream, + select as pipeline_select, ) from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter from homeassistant.const import __version__ from homeassistant.core import HomeAssistant +from .const import DOMAIN + if TYPE_CHECKING: from .devices import VoIPDevice, VoIPDevices @@ -151,7 +154,9 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol): channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=stt_stream(), - language=self.language, + pipeline_id=pipeline_select.get_chosen_pipeline( + self.hass, DOMAIN, self.voip_device.voip_id + ), conversation_id=self._conversation_id, tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"}, ) diff --git a/homeassistant/helpers/collection.py b/homeassistant/helpers/collection.py index 2526a210d70..cfc1750f7e5 100644 --- a/homeassistant/helpers/collection.py +++ b/homeassistant/helpers/collection.py @@ -143,20 +143,24 @@ class ObservableCollection(ABC, Generic[_ItemT]): return list(self.data.values()) @callback - def async_add_listener(self, listener: ChangeListener) -> None: + def async_add_listener(self, listener: ChangeListener) -> Callable[[], None]: """Add a listener. Will be called with (change_type, item_id, updated_config). """ self.listeners.append(listener) + return lambda: self.listeners.remove(listener) @callback - def async_add_change_set_listener(self, listener: ChangeSetListener) -> None: + def async_add_change_set_listener( + self, listener: ChangeSetListener + ) -> Callable[[], None]: """Add a listener for a full change set. Will be called with [(change_type, item_id, updated_config), ...] """ self.change_set_listeners.append(listener) + return lambda: self.change_set_listeners.remove(listener) async def notify_changes(self, change_sets: Iterable[CollectionChangeSet]) -> None: """Notify listeners of a change.""" diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 06c78651bb6..b010236af09 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -6,6 +6,8 @@ from unittest.mock import AsyncMock, Mock import pytest from homeassistant.components import stt, tts +from homeassistant.components.assist_pipeline import DOMAIN +from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection from homeassistant.core import HomeAssistant from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_setup_component @@ -137,3 +139,9 @@ async def init_components( assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) assert await async_setup_component(hass, "media_source", {}) assert await async_setup_component(hass, "assist_pipeline", {}) + + +@pytest.fixture +def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: + """Return pipeline storage collection.""" + return hass.data[DOMAIN].pipeline_store diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py new file mode 100644 index 00000000000..540ed98da13 --- /dev/null +++ b/tests/components/assist_pipeline/test_select.py @@ -0,0 +1,117 @@ +"""Test select entity.""" + +from __future__ import annotations + +import pytest + +from homeassistant.components.assist_pipeline import Pipeline +from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection +from homeassistant.components.assist_pipeline.select import AssistPipelineSelect +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from tests.common import MockConfigEntry, MockPlatform, mock_entity_platform + + +class SelectPlatform(MockPlatform): + """Fake select platform.""" + + # pylint: disable=method-hidden + async def async_setup_entry( + self, + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, + ) -> None: + """Set up fake select platform.""" + async_add_entities([AssistPipelineSelect(hass, "test")]) + + +@pytest.fixture +async def init_select(hass: HomeAssistant, init_components) -> ConfigEntry: + """Initialize select entity.""" + mock_entity_platform(hass, "select.assist_pipeline", SelectPlatform()) + config_entry = MockConfigEntry(domain="assist_pipeline") + assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") + return config_entry + + +@pytest.fixture +async def pipeline_1( + hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection +) -> Pipeline: + """Create a pipeline.""" + return await pipeline_storage.async_create_item( + { + "name": "Test 1", + "language": "en-US", + "conversation_engine": None, + "tts_engine": None, + "stt_engine": None, + } + ) + + +@pytest.fixture +async def pipeline_2( + hass: HomeAssistant, init_select, pipeline_storage: PipelineStorageCollection +) -> Pipeline: + """Create a pipeline.""" + return await pipeline_storage.async_create_item( + { + "name": "Test 2", + "language": "en-US", + "conversation_engine": None, + "tts_engine": None, + "stt_engine": None, + } + ) + + +async def test_select_entity_changing_pipelines( + hass: HomeAssistant, + init_select: ConfigEntry, + pipeline_1: Pipeline, + pipeline_2: Pipeline, + pipeline_storage: PipelineStorageCollection, +) -> None: + """Test entity tracking pipeline changes.""" + config_entry = init_select # nicer naming + + state = hass.states.get("select.assist_pipeline_test_pipeline") + assert state is not None + assert state.state == "preferred" + assert state.attributes["options"] == [ + "preferred", + pipeline_1.name, + pipeline_2.name, + ] + + # Change select to new pipeline + await hass.services.async_call( + "select", + "select_option", + { + "entity_id": "select.assist_pipeline_test_pipeline", + "option": pipeline_2.name, + }, + blocking=True, + ) + + state = hass.states.get("select.assist_pipeline_test_pipeline") + assert state.state == pipeline_2.name + + # Reload config entry to test selected option persists + assert await hass.config_entries.async_forward_entry_unload(config_entry, "select") + assert await hass.config_entries.async_forward_entry_setup(config_entry, "select") + + state = hass.states.get("select.assist_pipeline_test_pipeline") + assert state.state == pipeline_2.name + + # Remove selected pipeline + await pipeline_storage.async_delete_item(pipeline_2.id) + + state = hass.states.get("select.assist_pipeline_test_pipeline") + assert state.state == "preferred" + assert state.attributes["options"] == ["preferred", pipeline_1.name] diff --git a/tests/components/voip/test_select.py b/tests/components/voip/test_select.py new file mode 100644 index 00000000000..19c3202576a --- /dev/null +++ b/tests/components/voip/test_select.py @@ -0,0 +1,19 @@ +"""Test VoIP select.""" +from homeassistant.components.voip.devices import VoIPDevice +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + + +async def test_pipeline_select( + hass: HomeAssistant, + config_entry: ConfigEntry, + voip_device: VoIPDevice, +) -> None: + """Test pipeline select. + + Functionality is tested in assist_pipeline/test_select.py. + This test is only to ensure it is set up. + """ + state = hass.states.get("select.192_168_1_210_assist_pipeline") + assert state is not None + assert state.state == "preferred"