mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Enable strict typing of assist_pipeline (#91529)
This commit is contained in:
parent
9985516f80
commit
3367e86686
@ -61,6 +61,7 @@ homeassistant.components.anthemav.*
|
|||||||
homeassistant.components.apcupsd.*
|
homeassistant.components.apcupsd.*
|
||||||
homeassistant.components.aqualogic.*
|
homeassistant.components.aqualogic.*
|
||||||
homeassistant.components.aseko_pool_live.*
|
homeassistant.components.aseko_pool_live.*
|
||||||
|
homeassistant.components.assist_pipeline.*
|
||||||
homeassistant.components.asuswrt.*
|
homeassistant.components.asuswrt.*
|
||||||
homeassistant.components.auth.*
|
homeassistant.components.auth.*
|
||||||
homeassistant.components.automation.*
|
homeassistant.components.automation.*
|
||||||
|
@ -179,7 +179,7 @@ class PipelineRun:
|
|||||||
tts_engine: str | None = None
|
tts_engine: str | None = None
|
||||||
tts_options: dict | None = None
|
tts_options: dict | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
"""Set language for pipeline."""
|
"""Set language for pipeline."""
|
||||||
self.language = self.pipeline.language or self.hass.config.language
|
self.language = self.pipeline.language or self.hass.config.language
|
||||||
|
|
||||||
@ -189,7 +189,7 @@ class PipelineRun:
|
|||||||
):
|
):
|
||||||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||||
|
|
||||||
def start(self):
|
def start(self) -> None:
|
||||||
"""Emit run start event."""
|
"""Emit run start event."""
|
||||||
data = {
|
data = {
|
||||||
"pipeline": self.pipeline.name,
|
"pipeline": self.pipeline.name,
|
||||||
@ -200,7 +200,7 @@ class PipelineRun:
|
|||||||
|
|
||||||
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
|
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||||
|
|
||||||
def end(self):
|
def end(self) -> None:
|
||||||
"""Emit run end event."""
|
"""Emit run end event."""
|
||||||
self.event_callback(
|
self.event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
@ -349,7 +349,9 @@ class PipelineRun:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
speech = conversation_result.response.speech.get("plain", {}).get("speech", "")
|
speech: str = conversation_result.response.speech.get("plain", {}).get(
|
||||||
|
"speech", ""
|
||||||
|
)
|
||||||
|
|
||||||
return speech
|
return speech
|
||||||
|
|
||||||
@ -453,7 +455,7 @@ class PipelineInput:
|
|||||||
|
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
|
|
||||||
async def execute(self):
|
async def execute(self) -> None:
|
||||||
"""Run pipeline."""
|
"""Run pipeline."""
|
||||||
self.run.start()
|
self.run.start()
|
||||||
current_stage = self.run.start_stage
|
current_stage = self.run.start_stage
|
||||||
@ -496,7 +498,7 @@ class PipelineInput:
|
|||||||
|
|
||||||
self.run.end()
|
self.run.end()
|
||||||
|
|
||||||
async def validate(self):
|
async def validate(self) -> None:
|
||||||
"""Validate pipeline input against start stage."""
|
"""Validate pipeline input against start stage."""
|
||||||
if self.run.start_stage == PipelineStage.STT:
|
if self.run.start_stage == PipelineStage.STT:
|
||||||
if self.stt_metadata is None:
|
if self.stt_metadata is None:
|
||||||
@ -524,7 +526,8 @@ class PipelineInput:
|
|||||||
prepare_tasks = []
|
prepare_tasks = []
|
||||||
|
|
||||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
|
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
|
||||||
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))
|
# self.stt_metadata can't be None or we'd raise above
|
||||||
|
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata)) # type: ignore[arg-type]
|
||||||
|
|
||||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
|
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
|
||||||
prepare_tasks.append(self.run.prepare_recognize_intent())
|
prepare_tasks.append(self.run.prepare_recognize_intent())
|
||||||
@ -696,7 +699,7 @@ class PipelineStorageCollectionWebsocket(
|
|||||||
connection.send_result(msg["id"])
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_pipeline_store(hass):
|
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
||||||
"""Set up the pipeline storage collection."""
|
"""Set up the pipeline storage collection."""
|
||||||
pipeline_store = PipelineStorageCollection(
|
pipeline_store = PipelineStorageCollection(
|
||||||
Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
@ -48,14 +48,14 @@ class VoiceCommandSegmenter:
|
|||||||
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
|
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
|
||||||
_seconds_per_chunk: float = 0.03 # 30 ms
|
_seconds_per_chunk: float = 0.03 # 30 ms
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
"""Initialize VAD."""
|
"""Initialize VAD."""
|
||||||
self._vad = webrtcvad.Vad(self.vad_mode)
|
self._vad = webrtcvad.Vad(self.vad_mode)
|
||||||
self._bytes_per_chunk = self.vad_frames * 2
|
self._bytes_per_chunk = self.vad_frames * 2
|
||||||
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
|
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
"""Reset all counters and state."""
|
"""Reset all counters and state."""
|
||||||
self._audio_buffer = b""
|
self._audio_buffer = b""
|
||||||
self._speech_seconds_left = self.speech_seconds
|
self._speech_seconds_left = self.speech_seconds
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Assist pipeline Websocket API."""
|
"""Assist pipeline Websocket API."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import audioop # pylint: disable=deprecated-module
|
import audioop # pylint: disable=deprecated-module
|
||||||
from collections.abc import Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ async def websocket_run(
|
|||||||
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
||||||
incoming_sample_rate = msg["input"]["sample_rate"]
|
incoming_sample_rate = msg["input"]["sample_rate"]
|
||||||
|
|
||||||
async def stt_stream():
|
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
||||||
state = None
|
state = None
|
||||||
segmenter = VoiceCommandSegmenter()
|
segmenter = VoiceCommandSegmenter()
|
||||||
|
|
||||||
@ -129,7 +129,11 @@ async def websocket_run(
|
|||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
def handle_binary(_hass, _connection, data: bytes):
|
def handle_binary(
|
||||||
|
_hass: HomeAssistant,
|
||||||
|
_connection: websocket_api.ActiveConnection,
|
||||||
|
data: bytes,
|
||||||
|
) -> None:
|
||||||
# Forward to STT audio stream
|
# Forward to STT audio stream
|
||||||
audio_queue.put_nowait(data)
|
audio_queue.put_nowait(data)
|
||||||
|
|
||||||
|
10
mypy.ini
10
mypy.ini
@ -371,6 +371,16 @@ disallow_untyped_defs = true
|
|||||||
warn_return_any = true
|
warn_return_any = true
|
||||||
warn_unreachable = true
|
warn_unreachable = true
|
||||||
|
|
||||||
|
[mypy-homeassistant.components.assist_pipeline.*]
|
||||||
|
check_untyped_defs = true
|
||||||
|
disallow_incomplete_defs = true
|
||||||
|
disallow_subclassing_any = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_decorators = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unreachable = true
|
||||||
|
|
||||||
[mypy-homeassistant.components.asuswrt.*]
|
[mypy-homeassistant.components.asuswrt.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_incomplete_defs = true
|
disallow_incomplete_defs = true
|
||||||
|
Loading…
x
Reference in New Issue
Block a user