diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 8979cae068e..a11b5a657de 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -50,6 +50,7 @@ from homeassistant.util import ( language as language_util, ulid as ulid_util, ) +from homeassistant.util.hass_dict import HassKey from homeassistant.util.limited_size_dict import LimitedSizeDict from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer @@ -91,6 +92,8 @@ ENGINE_LANGUAGE_PAIRS = ( ("tts_engine", "tts_language"), ) +KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN) + def validate_language(data: dict[str, Any]) -> Any: """Validate language settings.""" @@ -248,7 +251,7 @@ async def async_create_default_pipeline( The default pipeline will use the homeassistant conversation agent and the specified stt / tts engines. """ - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] pipeline_store = pipeline_data.pipeline_store pipeline_settings = _async_resolve_default_pipeline_settings( hass, @@ -283,7 +286,7 @@ def _async_get_pipeline_from_conversation_entity( @callback def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline: """Get a pipeline by id or the preferred pipeline.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] if pipeline_id is None: # A pipeline was not specified, use the preferred one @@ -306,7 +309,7 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P @callback def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]: """Get all pipelines.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] return list(pipeline_data.pipeline_store.data.values()) @@ -329,7 +332,7 @@ async def async_update_pipeline( prefer_local_intents: bool | UndefinedType = UNDEFINED, ) -> None: """Update a pipeline.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] updates: dict[str, Any] = pipeline.to_json() updates.pop("id") @@ -587,7 +590,7 @@ class PipelineRun: ): raise InvalidPipelineStagesError(self.start_stage, self.end_stage) - pipeline_data: PipelineData = self.hass.data[DOMAIN] + pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE] if self.pipeline.id not in pipeline_data.pipeline_debug: pipeline_data.pipeline_debug[self.pipeline.id] = LimitedSizeDict( size_limit=STORED_PIPELINE_RUNS @@ -615,7 +618,7 @@ class PipelineRun: def process_event(self, event: PipelineEvent) -> None: """Log an event and call listener.""" self.event_callback(event) - pipeline_data: PipelineData = self.hass.data[DOMAIN] + pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE] if self.id not in pipeline_data.pipeline_debug[self.pipeline.id]: # This run has been evicted from the logged pipeline runs already return @@ -650,7 +653,7 @@ class PipelineRun: ) ) - pipeline_data: PipelineData = self.hass.data[DOMAIN] + pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE] pipeline_data.pipeline_runs.remove_run(self) async def prepare_wake_word_detection(self) -> None: @@ -1227,7 +1230,7 @@ class PipelineRun: return # Forward to device audio capture - pipeline_data: PipelineData = self.hass.data[DOMAIN] + pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE] audio_queue = pipeline_data.device_audio_queues.get(self._device_id) if audio_queue is None: return @@ -1884,7 +1887,7 @@ class PipelineStore(Store[SerializedPipelineStorageCollection]): return old_data -@singleton(DOMAIN) +@singleton(KEY_ASSIST_PIPELINE, async_=True) async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData: """Set up the pipeline storage collection.""" pipeline_store = PipelineStorageCollection( diff --git a/homeassistant/components/assist_pipeline/select.py b/homeassistant/components/assist_pipeline/select.py index c7e4846aad7..a590f30fc7a 100644 --- a/homeassistant/components/assist_pipeline/select.py +++ b/homeassistant/components/assist_pipeline/select.py @@ -9,8 +9,8 @@ from homeassistant.const import EntityCategory, Platform from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import collection, entity_registry as er, restore_state -from .const import DOMAIN, OPTION_PREFERRED -from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection +from .const import OPTION_PREFERRED +from .pipeline import KEY_ASSIST_PIPELINE, AssistDevice from .vad import VadSensitivity @@ -30,7 +30,7 @@ def get_chosen_pipeline( if state is None or state.state == OPTION_PREFERRED: return None - pipeline_store: PipelineStorageCollection = hass.data[DOMAIN].pipeline_store + pipeline_store = hass.data[KEY_ASSIST_PIPELINE].pipeline_store return next( (item.id for item in pipeline_store.async_items() if item.name == state.state), None, @@ -80,7 +80,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): """When entity is added to Home Assistant.""" await super().async_added_to_hass() - pipeline_data: PipelineData = self.hass.data[DOMAIN] + pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE] pipeline_store = pipeline_data.pipeline_store self.async_on_remove( pipeline_store.async_add_change_set_listener(self._pipelines_updated) @@ -116,9 +116,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity): @callback def _update_options(self) -> None: """Handle pipeline update.""" - pipeline_store: PipelineStorageCollection = self.hass.data[ - DOMAIN - ].pipeline_store + pipeline_store = self.hass.data[KEY_ASSIST_PIPELINE].pipeline_store options = [OPTION_PREFERRED] options.extend(sorted(item.name for item in pipeline_store.async_items())) self._attr_options = options diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index d61580f4a14..e8da8e56fd6 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -21,7 +21,6 @@ from homeassistant.util import language as language_util from .const import ( DEFAULT_PIPELINE_TIMEOUT, DEFAULT_WAKE_WORD_TIMEOUT, - DOMAIN, EVENT_RECORDING, SAMPLE_CHANNELS, SAMPLE_RATE, @@ -29,9 +28,9 @@ from .const import ( ) from .error import PipelineNotFound from .pipeline import ( + KEY_ASSIST_PIPELINE, AudioSettings, DeviceAudioQueue, - PipelineData, PipelineError, PipelineEvent, PipelineEventType, @@ -283,7 +282,7 @@ def websocket_list_runs( msg: dict[str, Any], ) -> None: """List pipeline runs for which debug data is available.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] pipeline_id = msg["pipeline_id"] if pipeline_id not in pipeline_data.pipeline_debug: @@ -319,7 +318,7 @@ def websocket_list_devices( msg: dict[str, Any], ) -> None: """List assist devices.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] ent_reg = er.async_get(hass) connection.send_result( msg["id"], @@ -350,7 +349,7 @@ def websocket_get_run( msg: dict[str, Any], ) -> None: """Get debug data for a pipeline run.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] pipeline_id = msg["pipeline_id"] pipeline_run_id = msg["pipeline_run_id"] @@ -455,7 +454,7 @@ async def websocket_device_capture( msg: dict[str, Any], ) -> None: """Capture raw audio from a satellite device and forward to client.""" - pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_data = hass.data[KEY_ASSIST_PIPELINE] device_id = msg["device_id"] # Number of seconds to record audio in wall clock time