mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Refactor cloud assist pipeline (#105723)
* Refactor cloud assist pipeline * Return None early
This commit is contained in:
parent
f4c8920231
commit
4da04a358a
44
homeassistant/components/cloud/assist_pipeline.py
Normal file
44
homeassistant/components/cloud/assist_pipeline.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""Handle Cloud assist pipelines."""
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
async_create_default_pipeline,
|
||||||
|
async_get_pipelines,
|
||||||
|
async_setup_pipeline_store,
|
||||||
|
)
|
||||||
|
from homeassistant.components.conversation import HOME_ASSISTANT_AGENT
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
|
|
||||||
|
async def async_create_cloud_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
|
"""Create a cloud assist pipeline."""
|
||||||
|
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
||||||
|
# is an after dependency of cloud
|
||||||
|
await async_setup_pipeline_store(hass)
|
||||||
|
|
||||||
|
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return the ID of a cloud-enabled assist pipeline or None.
|
||||||
|
|
||||||
|
Check if a cloud pipeline already exists with
|
||||||
|
legacy cloud engine id.
|
||||||
|
"""
|
||||||
|
for pipeline in async_get_pipelines(hass):
|
||||||
|
if (
|
||||||
|
pipeline.conversation_engine == HOME_ASSISTANT_AGENT
|
||||||
|
and pipeline.stt_engine == DOMAIN
|
||||||
|
and pipeline.tts_engine == DOMAIN
|
||||||
|
):
|
||||||
|
return pipeline.id
|
||||||
|
return None
|
||||||
|
|
||||||
|
if (cloud_assist_pipeline(hass)) is not None or (
|
||||||
|
cloud_pipeline := await async_create_default_pipeline(
|
||||||
|
hass,
|
||||||
|
stt_engine_id=DOMAIN,
|
||||||
|
tts_engine_id=DOMAIN,
|
||||||
|
pipeline_name="Home Assistant Cloud",
|
||||||
|
)
|
||||||
|
) is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cloud_pipeline.id
|
@ -1,4 +1,6 @@
|
|||||||
"""The HTTP api to control the cloud integration."""
|
"""The HTTP api to control the cloud integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable, Coroutine, Mapping
|
from collections.abc import Awaitable, Callable, Coroutine, Mapping
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
@ -16,7 +18,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
|
|||||||
from hass_nabucasa.voice import MAP_VOICE
|
from hass_nabucasa.voice import MAP_VOICE
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation, websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.alexa import (
|
from homeassistant.components.alexa import (
|
||||||
entities as alexa_entities,
|
entities as alexa_entities,
|
||||||
errors as alexa_errors,
|
errors as alexa_errors,
|
||||||
@ -32,6 +34,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
|||||||
from homeassistant.util.location import async_detect_location_info
|
from homeassistant.util.location import async_detect_location_info
|
||||||
|
|
||||||
from .alexa_config import entity_supported as entity_supported_by_alexa
|
from .alexa_config import entity_supported as entity_supported_by_alexa
|
||||||
|
from .assist_pipeline import async_create_cloud_pipeline
|
||||||
from .client import CloudClient
|
from .client import CloudClient
|
||||||
from .const import (
|
from .const import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
@ -210,34 +213,11 @@ class CloudLoginView(HomeAssistantView):
|
|||||||
)
|
)
|
||||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||||
"""Handle login request."""
|
"""Handle login request."""
|
||||||
|
hass: HomeAssistant = request.app["hass"]
|
||||||
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
|
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
|
||||||
"""Return the ID of a cloud-enabled assist pipeline or None."""
|
|
||||||
for pipeline in assist_pipeline.async_get_pipelines(hass):
|
|
||||||
if (
|
|
||||||
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
|
|
||||||
and pipeline.stt_engine == DOMAIN
|
|
||||||
and pipeline.tts_engine == DOMAIN
|
|
||||||
):
|
|
||||||
return pipeline.id
|
|
||||||
return None
|
|
||||||
|
|
||||||
hass = request.app["hass"]
|
|
||||||
cloud = hass.data[DOMAIN]
|
|
||||||
await cloud.login(data["email"], data["password"])
|
await cloud.login(data["email"], data["password"])
|
||||||
|
|
||||||
# Make sure the pipeline store is loaded, needed because assist_pipeline
|
new_cloud_pipeline_id = await async_create_cloud_pipeline(hass)
|
||||||
# is an after dependency of cloud
|
|
||||||
await assist_pipeline.async_setup_pipeline_store(hass)
|
|
||||||
new_cloud_pipeline_id: str | None = None
|
|
||||||
if (cloud_assist_pipeline(hass)) is None:
|
|
||||||
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
|
|
||||||
hass,
|
|
||||||
stt_engine_id=DOMAIN,
|
|
||||||
tts_engine_id=DOMAIN,
|
|
||||||
pipeline_name="Home Assistant Cloud",
|
|
||||||
):
|
|
||||||
new_cloud_pipeline_id = cloud_pipeline.id
|
|
||||||
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})
|
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ async def test_login_view_existing_pipeline(
|
|||||||
cloud_client = await hass_client()
|
cloud_client = await hass_client()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
|
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||||
) as create_pipeline_mock:
|
) as create_pipeline_mock:
|
||||||
req = await cloud_client.post(
|
req = await cloud_client.post(
|
||||||
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
|
||||||
@ -183,7 +183,7 @@ async def test_login_view_create_pipeline(
|
|||||||
cloud_client = await hass_client()
|
cloud_client = await hass_client()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
|
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||||
return_value=AsyncMock(id="12345"),
|
return_value=AsyncMock(id="12345"),
|
||||||
) as create_pipeline_mock:
|
) as create_pipeline_mock:
|
||||||
req = await cloud_client.post(
|
req = await cloud_client.post(
|
||||||
@ -222,7 +222,7 @@ async def test_login_view_create_pipeline_fail(
|
|||||||
cloud_client = await hass_client()
|
cloud_client = await hass_client()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
|
"homeassistant.components.cloud.assist_pipeline.async_create_default_pipeline",
|
||||||
return_value=None,
|
return_value=None,
|
||||||
) as create_pipeline_mock:
|
) as create_pipeline_mock:
|
||||||
req = await cloud_client.post(
|
req = await cloud_client.post(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user