mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Use HassKey for assist_pipeline singleton (#135875)
This commit is contained in:
parent
24c50e0988
commit
19e5b091c5
@ -50,6 +50,7 @@ from homeassistant.util import (
|
|||||||
language as language_util,
|
language as language_util,
|
||||||
ulid as ulid_util,
|
ulid as ulid_util,
|
||||||
)
|
)
|
||||||
|
from homeassistant.util.hass_dict import HassKey
|
||||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||||
|
|
||||||
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
|
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
|
||||||
@ -91,6 +92,8 @@ ENGINE_LANGUAGE_PAIRS = (
|
|||||||
("tts_engine", "tts_language"),
|
("tts_engine", "tts_language"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
KEY_ASSIST_PIPELINE: HassKey[PipelineData] = HassKey(DOMAIN)
|
||||||
|
|
||||||
|
|
||||||
def validate_language(data: dict[str, Any]) -> Any:
|
def validate_language(data: dict[str, Any]) -> Any:
|
||||||
"""Validate language settings."""
|
"""Validate language settings."""
|
||||||
@ -248,7 +251,7 @@ async def async_create_default_pipeline(
|
|||||||
The default pipeline will use the homeassistant conversation agent and the
|
The default pipeline will use the homeassistant conversation agent and the
|
||||||
specified stt / tts engines.
|
specified stt / tts engines.
|
||||||
"""
|
"""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
pipeline_settings = _async_resolve_default_pipeline_settings(
|
pipeline_settings = _async_resolve_default_pipeline_settings(
|
||||||
hass,
|
hass,
|
||||||
@ -283,7 +286,7 @@ def _async_get_pipeline_from_conversation_entity(
|
|||||||
@callback
|
@callback
|
||||||
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
||||||
"""Get a pipeline by id or the preferred pipeline."""
|
"""Get a pipeline by id or the preferred pipeline."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
|
|
||||||
if pipeline_id is None:
|
if pipeline_id is None:
|
||||||
# A pipeline was not specified, use the preferred one
|
# A pipeline was not specified, use the preferred one
|
||||||
@ -306,7 +309,7 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
|
|||||||
@callback
|
@callback
|
||||||
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
|
def async_get_pipelines(hass: HomeAssistant) -> list[Pipeline]:
|
||||||
"""Get all pipelines."""
|
"""Get all pipelines."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
|
|
||||||
return list(pipeline_data.pipeline_store.data.values())
|
return list(pipeline_data.pipeline_store.data.values())
|
||||||
|
|
||||||
@ -329,7 +332,7 @@ async def async_update_pipeline(
|
|||||||
prefer_local_intents: bool | UndefinedType = UNDEFINED,
|
prefer_local_intents: bool | UndefinedType = UNDEFINED,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a pipeline."""
|
"""Update a pipeline."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
|
|
||||||
updates: dict[str, Any] = pipeline.to_json()
|
updates: dict[str, Any] = pipeline.to_json()
|
||||||
updates.pop("id")
|
updates.pop("id")
|
||||||
@ -587,7 +590,7 @@ class PipelineRun:
|
|||||||
):
|
):
|
||||||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||||
|
|
||||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||||
if self.pipeline.id not in pipeline_data.pipeline_debug:
|
if self.pipeline.id not in pipeline_data.pipeline_debug:
|
||||||
pipeline_data.pipeline_debug[self.pipeline.id] = LimitedSizeDict(
|
pipeline_data.pipeline_debug[self.pipeline.id] = LimitedSizeDict(
|
||||||
size_limit=STORED_PIPELINE_RUNS
|
size_limit=STORED_PIPELINE_RUNS
|
||||||
@ -615,7 +618,7 @@ class PipelineRun:
|
|||||||
def process_event(self, event: PipelineEvent) -> None:
|
def process_event(self, event: PipelineEvent) -> None:
|
||||||
"""Log an event and call listener."""
|
"""Log an event and call listener."""
|
||||||
self.event_callback(event)
|
self.event_callback(event)
|
||||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||||
if self.id not in pipeline_data.pipeline_debug[self.pipeline.id]:
|
if self.id not in pipeline_data.pipeline_debug[self.pipeline.id]:
|
||||||
# This run has been evicted from the logged pipeline runs already
|
# This run has been evicted from the logged pipeline runs already
|
||||||
return
|
return
|
||||||
@ -650,7 +653,7 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||||
pipeline_data.pipeline_runs.remove_run(self)
|
pipeline_data.pipeline_runs.remove_run(self)
|
||||||
|
|
||||||
async def prepare_wake_word_detection(self) -> None:
|
async def prepare_wake_word_detection(self) -> None:
|
||||||
@ -1227,7 +1230,7 @@ class PipelineRun:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Forward to device audio capture
|
# Forward to device audio capture
|
||||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||||
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
|
audio_queue = pipeline_data.device_audio_queues.get(self._device_id)
|
||||||
if audio_queue is None:
|
if audio_queue is None:
|
||||||
return
|
return
|
||||||
@ -1884,7 +1887,7 @@ class PipelineStore(Store[SerializedPipelineStorageCollection]):
|
|||||||
return old_data
|
return old_data
|
||||||
|
|
||||||
|
|
||||||
@singleton(DOMAIN)
|
@singleton(KEY_ASSIST_PIPELINE, async_=True)
|
||||||
async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
|
async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
|
||||||
"""Set up the pipeline storage collection."""
|
"""Set up the pipeline storage collection."""
|
||||||
pipeline_store = PipelineStorageCollection(
|
pipeline_store = PipelineStorageCollection(
|
||||||
|
@ -9,8 +9,8 @@ from homeassistant.const import EntityCategory, Platform
|
|||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||||
|
|
||||||
from .const import DOMAIN, OPTION_PREFERRED
|
from .const import OPTION_PREFERRED
|
||||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
from .pipeline import KEY_ASSIST_PIPELINE, AssistDevice
|
||||||
from .vad import VadSensitivity
|
from .vad import VadSensitivity
|
||||||
|
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ def get_chosen_pipeline(
|
|||||||
if state is None or state.state == OPTION_PREFERRED:
|
if state is None or state.state == OPTION_PREFERRED:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pipeline_store: PipelineStorageCollection = hass.data[DOMAIN].pipeline_store
|
pipeline_store = hass.data[KEY_ASSIST_PIPELINE].pipeline_store
|
||||||
return next(
|
return next(
|
||||||
(item.id for item in pipeline_store.async_items() if item.name == state.state),
|
(item.id for item in pipeline_store.async_items() if item.name == state.state),
|
||||||
None,
|
None,
|
||||||
@ -80,7 +80,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
|||||||
"""When entity is added to Home Assistant."""
|
"""When entity is added to Home Assistant."""
|
||||||
await super().async_added_to_hass()
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
pipeline_data = self.hass.data[KEY_ASSIST_PIPELINE]
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
self.async_on_remove(
|
self.async_on_remove(
|
||||||
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
pipeline_store.async_add_change_set_listener(self._pipelines_updated)
|
||||||
@ -116,9 +116,7 @@ class AssistPipelineSelect(SelectEntity, restore_state.RestoreEntity):
|
|||||||
@callback
|
@callback
|
||||||
def _update_options(self) -> None:
|
def _update_options(self) -> None:
|
||||||
"""Handle pipeline update."""
|
"""Handle pipeline update."""
|
||||||
pipeline_store: PipelineStorageCollection = self.hass.data[
|
pipeline_store = self.hass.data[KEY_ASSIST_PIPELINE].pipeline_store
|
||||||
DOMAIN
|
|
||||||
].pipeline_store
|
|
||||||
options = [OPTION_PREFERRED]
|
options = [OPTION_PREFERRED]
|
||||||
options.extend(sorted(item.name for item in pipeline_store.async_items()))
|
options.extend(sorted(item.name for item in pipeline_store.async_items()))
|
||||||
self._attr_options = options
|
self._attr_options = options
|
||||||
|
@ -21,7 +21,6 @@ from homeassistant.util import language as language_util
|
|||||||
from .const import (
|
from .const import (
|
||||||
DEFAULT_PIPELINE_TIMEOUT,
|
DEFAULT_PIPELINE_TIMEOUT,
|
||||||
DEFAULT_WAKE_WORD_TIMEOUT,
|
DEFAULT_WAKE_WORD_TIMEOUT,
|
||||||
DOMAIN,
|
|
||||||
EVENT_RECORDING,
|
EVENT_RECORDING,
|
||||||
SAMPLE_CHANNELS,
|
SAMPLE_CHANNELS,
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
@ -29,9 +28,9 @@ from .const import (
|
|||||||
)
|
)
|
||||||
from .error import PipelineNotFound
|
from .error import PipelineNotFound
|
||||||
from .pipeline import (
|
from .pipeline import (
|
||||||
|
KEY_ASSIST_PIPELINE,
|
||||||
AudioSettings,
|
AudioSettings,
|
||||||
DeviceAudioQueue,
|
DeviceAudioQueue,
|
||||||
PipelineData,
|
|
||||||
PipelineError,
|
PipelineError,
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
@ -283,7 +282,7 @@ def websocket_list_runs(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List pipeline runs for which debug data is available."""
|
"""List pipeline runs for which debug data is available."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
pipeline_id = msg["pipeline_id"]
|
pipeline_id = msg["pipeline_id"]
|
||||||
|
|
||||||
if pipeline_id not in pipeline_data.pipeline_debug:
|
if pipeline_id not in pipeline_data.pipeline_debug:
|
||||||
@ -319,7 +318,7 @@ def websocket_list_devices(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""List assist devices."""
|
"""List assist devices."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
ent_reg = er.async_get(hass)
|
ent_reg = er.async_get(hass)
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
@ -350,7 +349,7 @@ def websocket_get_run(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Get debug data for a pipeline run."""
|
"""Get debug data for a pipeline run."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
pipeline_id = msg["pipeline_id"]
|
pipeline_id = msg["pipeline_id"]
|
||||||
pipeline_run_id = msg["pipeline_run_id"]
|
pipeline_run_id = msg["pipeline_run_id"]
|
||||||
|
|
||||||
@ -455,7 +454,7 @@ async def websocket_device_capture(
|
|||||||
msg: dict[str, Any],
|
msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Capture raw audio from a satellite device and forward to client."""
|
"""Capture raw audio from a satellite device and forward to client."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data = hass.data[KEY_ASSIST_PIPELINE]
|
||||||
device_id = msg["device_id"]
|
device_id = msg["device_id"]
|
||||||
|
|
||||||
# Number of seconds to record audio in wall clock time
|
# Number of seconds to record audio in wall clock time
|
||||||
|
Loading…
x
Reference in New Issue
Block a user