mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Speech to Text component (#28434)
* Initial commit for STT * Fix code review
This commit is contained in:
parent
33c8cba30d
commit
99c0559a0c
@ -280,6 +280,7 @@ homeassistant/components/sql/* @dgomes
|
|||||||
homeassistant/components/statistics/* @fabaff
|
homeassistant/components/statistics/* @fabaff
|
||||||
homeassistant/components/stiebel_eltron/* @fucm
|
homeassistant/components/stiebel_eltron/* @fucm
|
||||||
homeassistant/components/stream/* @hunterjm
|
homeassistant/components/stream/* @hunterjm
|
||||||
|
homeassistant/components/stt/* @pvizeli
|
||||||
homeassistant/components/suez_water/* @ooii
|
homeassistant/components/suez_water/* @ooii
|
||||||
homeassistant/components/sun/* @Swamp-Ig
|
homeassistant/components/sun/* @Swamp-Ig
|
||||||
homeassistant/components/supla/* @mwegrzynek
|
homeassistant/components/supla/* @mwegrzynek
|
||||||
|
@ -25,6 +25,7 @@ COMPONENTS_WITH_DEMO_PLATFORM = [
|
|||||||
"media_player",
|
"media_player",
|
||||||
"notify",
|
"notify",
|
||||||
"sensor",
|
"sensor",
|
||||||
|
"stt",
|
||||||
"switch",
|
"switch",
|
||||||
"tts",
|
"tts",
|
||||||
"mailbox",
|
"mailbox",
|
||||||
|
60
homeassistant/components/demo/stt.py
Normal file
60
homeassistant/components/demo/stt.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""Support for the demo for speech to text service."""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from aiohttp import StreamReader
|
||||||
|
|
||||||
|
from homeassistant.components.stt import Provider, SpeechMetadata, SpeechResult
|
||||||
|
from homeassistant.components.stt.const import (
|
||||||
|
AudioBitrates,
|
||||||
|
AudioFormats,
|
||||||
|
AudioSamplerates,
|
||||||
|
AudioCodecs,
|
||||||
|
SpeechResultState,
|
||||||
|
)
|
||||||
|
|
||||||
|
SUPPORT_LANGUAGES = ["en", "de"]
|
||||||
|
|
||||||
|
|
||||||
|
async def async_get_engine(hass, config):
    """Create the speech-to-text provider of the demo platform.

    Neither ``hass`` nor ``config`` is used by the demo; a fresh
    ``DemoProvider`` is returned on every call.
    """
    demo_provider = DemoProvider()
    return demo_provider
|
||||||
|
|
||||||
|
|
||||||
|
class DemoProvider(Provider):
    """Speech-to-text provider of the demo platform.

    Advertises a small fixed set of audio capabilities and answers every
    request with the same canned sentence.
    """

    @property
    def supported_languages(self) -> List[str]:
        """Return the language codes this provider accepts."""
        return SUPPORT_LANGUAGES

    @property
    def supported_formats(self) -> List[AudioFormats]:
        """Return the audio formats this provider accepts."""
        formats = [AudioFormats.WAV]
        return formats

    @property
    def supported_codecs(self) -> List[AudioCodecs]:
        """Return the audio codecs this provider accepts."""
        codecs = [AudioCodecs.PCM]
        return codecs

    @property
    def supported_bitrates(self) -> List[AudioBitrates]:
        """Return the bitrates this provider accepts."""
        bitrates = [AudioBitrates.BITRATE_16]
        return bitrates

    @property
    def supported_samplerates(self) -> List[AudioSamplerates]:
        """Return the sample rates this provider accepts."""
        samplerates = [
            AudioSamplerates.SAMPLERATE_16000,
            AudioSamplerates.SAMPLERATE_44100,
        ]
        return samplerates

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: StreamReader
    ) -> SpeechResult:
        """Consume the audio stream and return a fixed transcription."""
        # Drain the stream in 4096-byte chunks; the data itself is discarded.
        async for _chunk in stream.iter_chunked(4096):
            continue

        return SpeechResult("Turn the Kitchen Lights on", SpeechResultState.SUCCESS)
|
@ -1,4 +1,4 @@
|
|||||||
"""Support for the demo speech service."""
|
"""Support for the demo for text to speech service."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
217
homeassistant/components/stt/__init__.py
Normal file
217
homeassistant/components/stt/__init__.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
"""Provide functionality to STT."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from aiohttp import StreamReader, web
|
||||||
|
from aiohttp.hdrs import istr
|
||||||
|
from aiohttp.web_exceptions import (
|
||||||
|
HTTPNotFound,
|
||||||
|
HTTPUnsupportedMediaType,
|
||||||
|
HTTPBadRequest,
|
||||||
|
)
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.helpers import config_per_platform
|
||||||
|
from homeassistant.helpers.typing import HomeAssistantType
|
||||||
|
from homeassistant.setup import async_prepare_setup_platform
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
DOMAIN,
|
||||||
|
AudioBitrates,
|
||||||
|
AudioCodecs,
|
||||||
|
AudioFormats,
|
||||||
|
AudioSamplerates,
|
||||||
|
SpeechResultState,
|
||||||
|
)
|
||||||
|
|
||||||
|
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass: HomeAssistantType, config):
    """Set up the STT component.

    Every platform configured under the STT domain is loaded and asked for
    its provider. Providers that load successfully are collected by platform
    name and handed to the HTTP view that serves them.

    Returns True unconditionally: a platform that fails to load is logged
    and skipped rather than failing the whole component.
    """
    providers = {}

    async def async_setup_platform(p_type, p_config, disc_info=None):
        """Set up a single STT platform and register its provider."""
        platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
        if platform is None:
            return

        try:
            provider = await platform.async_get_engine(hass, p_config)
            if provider is None:
                _LOGGER.error("Error setting up platform %s", p_type)
                return

            provider.name = p_type
            provider.hass = hass

            providers[provider.name] = provider
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Error setting up platform: %s", p_type)
            return

    setup_tasks = [
        async_setup_platform(p_type, p_config)
        for p_type, p_config in config_per_platform(config, DOMAIN)
    ]

    if setup_tasks:
        # asyncio.wait() must be given futures/tasks: passing bare coroutines
        # is deprecated since Python 3.8 and rejected on 3.11+, so each
        # coroutine is scheduled explicitly first.
        await asyncio.wait([asyncio.ensure_future(task) for task in setup_tasks])

    hass.http.register_view(SpeechToTextView(providers))
    return True
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
class SpeechMetadata:
    """Description of an incoming audio stream."""

    # Language of the spoken audio.
    language: str = attr.ib()
    # Audio format and codec; stored as passed in, no converter is applied.
    format: AudioFormats = attr.ib()
    codec: AudioCodecs = attr.ib()
    # Bitrate and sample rate are run through int() when the object is built.
    bitrate: AudioBitrates = attr.ib(converter=int)
    samplerate: AudioSamplerates = attr.ib(converter=int)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
class SpeechResult:
    """Outcome of processing an audio stream."""

    # Recognized text.
    text: str = attr.ib()
    # State describing whether the recognition succeeded.
    result: SpeechResultState = attr.ib()
|
||||||
|
|
||||||
|
|
||||||
|
class Provider(ABC):
    """Base class every STT provider has to implement."""

    # Both attributes are assigned by the component setup once the provider
    # has been created by its platform.
    hass: Optional[HomeAssistantType] = None
    name: Optional[str] = None

    @property
    @abstractmethod
    def supported_languages(self) -> List[str]:
        """Return the languages the provider accepts."""

    @property
    @abstractmethod
    def supported_formats(self) -> List[AudioFormats]:
        """Return the audio formats the provider accepts."""

    @property
    @abstractmethod
    def supported_codecs(self) -> List[AudioCodecs]:
        """Return the audio codecs the provider accepts."""

    @property
    @abstractmethod
    def supported_bitrates(self) -> List[AudioBitrates]:
        """Return the bitrates the provider accepts."""

    @property
    @abstractmethod
    def supported_samplerates(self) -> List[AudioSamplerates]:
        """Return the sample rates the provider accepts."""

    @abstractmethod
    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: StreamReader
    ) -> SpeechResult:
        """Process an audio stream with the STT service.

        Only streaming of the content is allowed.
        """

    @callback
    def check_metadata(self, metadata: SpeechMetadata) -> bool:
        """Return True when every metadata field is supported by this provider."""
        # Checked in the same order as the fields are declared; evaluation
        # stops at the first unsupported value.
        return (
            metadata.language in self.supported_languages
            and metadata.format in self.supported_formats
            and metadata.codec in self.supported_codecs
            and metadata.bitrate in self.supported_bitrates
            and metadata.samplerate in self.supported_samplerates
        )
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechToTextView(HomeAssistantView):
    """STT view to generate a text from audio stream."""

    requires_auth = True
    url = "/api/stt/{provider}"
    name = "api:stt:provider"

    def __init__(self, providers: Dict[str, Provider]) -> None:
        """Initialize an STT view.

        ``providers`` maps the provider name used in the URL to its provider.
        """
        self.providers = providers

    @staticmethod
    def _metadata_from_header(request: web.Request) -> Optional[SpeechMetadata]:
        """Extract metadata from header.

        X-Speech-Content: format=wav; codec=pcm; samplerate=16000; bitrate=16; language=de_de

        Returns None (after logging a warning) when the header is missing,
        when fields are missing or unknown, or when ``bitrate`` /
        ``samplerate`` are not integers.
        """
        try:
            data = request.headers[istr("X-Speech-Content")].split(";")
        except KeyError:
            _LOGGER.warning("Missing X-Speech-Content")
            return None

        # Convert Header data: each "key=value" segment becomes one entry.
        args = {}
        for value in data:
            key, _, content = value.strip().partition("=")
            args[key] = content

        try:
            return SpeechMetadata(**args)
        except (TypeError, ValueError) as err:
            # TypeError: missing or unexpected fields.
            # ValueError: the int converters rejected a non-numeric value.
            _LOGGER.warning("Wrong format of X-Speech-Content: %s", err)
            return None

    async def post(self, request: web.Request, provider: str) -> web.Response:
        """Convert Speech (audio) to text.

        Raises HTTPNotFound for an unknown provider, HTTPBadRequest when the
        metadata header cannot be parsed and HTTPUnsupportedMediaType when
        the provider does not support the described audio.
        """
        if provider not in self.providers:
            raise HTTPNotFound()
        stt_provider: Provider = self.providers[provider]

        # Get metadata
        metadata = self._metadata_from_header(request)
        if not metadata:
            raise HTTPBadRequest()

        # Check format
        if not stt_provider.check_metadata(metadata):
            raise HTTPUnsupportedMediaType()

        # Process audio stream
        result = await stt_provider.async_process_audio_stream(
            metadata, request.content
        )

        # Return result
        return self.json(attr.asdict(result))

    async def get(self, request: web.Request, provider: str) -> web.Response:
        """Return provider specific audio information.

        Raises HTTPNotFound for an unknown provider.
        """
        if provider not in self.providers:
            raise HTTPNotFound()
        stt_provider: Provider = self.providers[provider]

        return self.json(
            {
                "languages": stt_provider.supported_languages,
                "formats": stt_provider.supported_formats,
                "codecs": stt_provider.supported_codecs,
                "samplerates": stt_provider.supported_samplerates,
                "bitrates": stt_provider.supported_bitrates,
            }
        )
|
48
homeassistant/components/stt/const.py
Normal file
48
homeassistant/components/stt/const.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
"""STT constante."""
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
DOMAIN = "stt"
|
||||||
|
|
||||||
|
|
||||||
|
class AudioCodecs(str, Enum):
    """Audio codecs known to the STT component.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    PCM = "pcm"
    OPUS = "opus"
|
||||||
|
|
||||||
|
|
||||||
|
class AudioFormats(str, Enum):
    """Audio formats known to the STT component.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    WAV = "wav"
    OGG = "ogg"
|
||||||
|
|
||||||
|
|
||||||
|
class AudioBitrates(int, Enum):
    """Audio bitrates known to the STT component.

    Members are also ``int`` instances, so they compare equal to their
    plain integer value.
    """

    BITRATE_8 = 8
    BITRATE_16 = 16
    BITRATE_24 = 24
    BITRATE_32 = 32
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSamplerates(int, Enum):
    """Audio sample rates known to the STT component.

    Members are also ``int`` instances, so they compare equal to their
    plain integer value.
    """

    SAMPLERATE_8000 = 8000
    SAMPLERATE_11000 = 11000
    SAMPLERATE_16000 = 16000
    SAMPLERATE_18900 = 18900
    SAMPLERATE_22000 = 22000
    SAMPLERATE_32000 = 32000
    SAMPLERATE_37800 = 37800
    SAMPLERATE_44100 = 44100
    SAMPLERATE_48000 = 48000
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechResultState(str, Enum):
    """Possible outcomes of a speech recognition request.

    Members are also ``str`` instances, so they compare equal to their
    plain string value.
    """

    SUCCESS = "success"
    ERROR = "error"
|
8
homeassistant/components/stt/manifest.json
Normal file
8
homeassistant/components/stt/manifest.json
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"domain": "stt",
|
||||||
|
"name": "Stt",
|
||||||
|
"documentation": "https://www.home-assistant.io/integrations/stt",
|
||||||
|
"requirements": [],
|
||||||
|
"dependencies": ["http"],
|
||||||
|
"codeowners": ["@pvizeli"]
|
||||||
|
}
|
1
homeassistant/components/stt/services.yaml
Normal file
1
homeassistant/components/stt/services.yaml
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Describes the format for available STT services
|
69
tests/components/demo/test_stt.py
Normal file
69
tests/components/demo/test_stt.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
"""The tests for the demo stt component."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.components import stt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def setup_comp(hass):
    """Set up the STT component with the demo platform before each test."""
    demo_config = {"stt": {"platform": "demo"}}
    hass.loop.run_until_complete(
        async_setup_component(hass, stt.DOMAIN, demo_config)
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def test_demo_settings(hass_client):
    """Test that the demo provider reports its supported audio settings."""
    http = await hass_client()

    resp = await http.get("/api/stt/demo")
    payload = await resp.json()

    expected = {
        "languages": ["en", "de"],
        "bitrates": [16],
        "samplerates": [16000, 44100],
        "formats": ["wav"],
        "codecs": ["pcm"],
    }
    assert resp.status == 200
    assert payload == expected
|
||||||
|
|
||||||
|
|
||||||
|
async def test_demo_speech_no_metadata(hass_client):
    """Test that audio posted without the metadata header is rejected with 400."""
    http = await hass_client()

    resp = await http.post("/api/stt/demo", data=b"Test")
    assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
async def test_demo_speech_wrong_metadata(hass_client):
    """Test that audio with a sample rate the demo does not support gets 415."""
    http = await hass_client()

    # 8000 is not one of the sample rates the demo provider advertises.
    speech_headers = {
        "X-Speech-Content": "format=wav; codec=pcm; samplerate=8000; bitrate=16; language=de"
    }
    resp = await http.post("/api/stt/demo", headers=speech_headers, data=b"Test")
    assert resp.status == 415
|
||||||
|
|
||||||
|
|
||||||
|
async def test_demo_speech(hass_client):
    """Test that supported audio is answered with the demo transcription."""
    http = await hass_client()

    speech_headers = {
        "X-Speech-Content": "format=wav; codec=pcm; samplerate=16000; bitrate=16; language=de"
    }
    resp = await http.post("/api/stt/demo", headers=speech_headers, data=b"Test")
    payload = await resp.json()

    assert resp.status == 200
    assert payload == {"text": "Turn the Kitchen Lights on", "result": "success"}
|
1
tests/components/stt/__init__.py
Normal file
1
tests/components/stt/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Speech to text tests."""
|
29
tests/components/stt/test_init.py
Normal file
29
tests/components/stt/test_init.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""Test STT component setup."""
|
||||||
|
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.components import stt
|
||||||
|
|
||||||
|
|
||||||
|
async def test_setup_comp(hass):
    """Test that the STT component sets up with an empty configuration."""
    empty_config = {"stt": {}}
    assert await async_setup_component(hass, stt.DOMAIN, empty_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_demo_settings_not_exists(hass, hass_client):
    """Test that requesting settings of an unknown provider returns 404."""
    assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
    http = await hass_client()

    resp = await http.get("/api/stt/beer")

    assert resp.status == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_demo_speech_not_exists(hass, hass_client):
    """Test that posting audio to an unknown provider returns 404."""
    assert await async_setup_component(hass, stt.DOMAIN, {"stt": {}})
    http = await hass_client()

    resp = await http.post("/api/stt/beer", data=b"test")

    assert resp.status == 404
|
Loading…
x
Reference in New Issue
Block a user