mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
HassTurnOn/Off intents to also handle cover entities (#86206)
* Move entity/area resolution to async_match_states * Special case for covers in HassTurnOn/Off * Enable light color/brightness on areas * Remove async_register from default agent * Remove CONFIG_SCHEMA from conversation component * Fix intent tests * Fix light test * Move entity/area resolution to async_match_states * Special case for covers in HassTurnOn/Off * Enable light color/brightness on areas * Remove async_register from default agent * Remove CONFIG_SCHEMA from conversation component * Fix intent tests * Fix light test * Fix humidifier intent handlers * Remove DATA_CONFIG for conversation * Copy ServiceIntentHandler code to light * Add proper errors to humidifier intent handlers
This commit is contained in:
parent
8f10c22a23
commit
5aca996f22
@ -16,7 +16,7 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationResult
|
||||
from .default_agent import DefaultAgent, async_register
|
||||
from .default_agent import DefaultAgent
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -27,7 +27,6 @@ DOMAIN = "conversation"
|
||||
|
||||
REGEX_TYPE = type(re.compile(""))
|
||||
DATA_AGENT = "conversation_agent"
|
||||
DATA_CONFIG = "conversation_config"
|
||||
|
||||
SERVICE_PROCESS = "process"
|
||||
SERVICE_RELOAD = "reload"
|
||||
@ -47,22 +46,6 @@ SERVICE_RELOAD_SCHEMA = vol.Schema(
|
||||
)
|
||||
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
vol.Optional("intents"): vol.Schema(
|
||||
{cv.string: vol.All(cv.ensure_list, [cv.string])}
|
||||
)
|
||||
}
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
async_register = bind_hass(async_register)
|
||||
|
||||
|
||||
@core.callback
|
||||
@bind_hass
|
||||
def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None):
|
||||
@ -72,7 +55,6 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent |
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Register the process service."""
|
||||
hass.data[DATA_CONFIG] = config
|
||||
|
||||
async def handle_process(service: core.ServiceCall) -> None:
|
||||
"""Parse text into commands."""
|
||||
@ -228,7 +210,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
|
||||
"""Get the active conversation agent."""
|
||||
if (agent := hass.data.get(DATA_AGENT)) is None:
|
||||
agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
|
||||
await agent.async_initialize(hass.data.get(DATA_CONFIG))
|
||||
await agent.async_initialize()
|
||||
return agent
|
||||
|
||||
|
||||
|
@ -20,30 +20,12 @@ from homeassistant.helpers import area_registry, entity_registry, intent
|
||||
|
||||
from .agent import AbstractConversationAgent, ConversationResult
|
||||
from .const import DOMAIN
|
||||
from .util import create_matcher
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
REGEX_TYPE = type(re.compile(""))
|
||||
|
||||
|
||||
@core.callback
|
||||
def async_register(hass, intent_type, utterances):
|
||||
"""Register utterances and any custom intents for the default agent.
|
||||
|
||||
Registrations don't require conversations to be loaded. They will become
|
||||
active once the conversation component is loaded.
|
||||
"""
|
||||
intents = hass.data.setdefault(DOMAIN, {})
|
||||
conf = intents.setdefault(intent_type, [])
|
||||
|
||||
for utterance in utterances:
|
||||
if isinstance(utterance, REGEX_TYPE):
|
||||
conf.append(utterance)
|
||||
else:
|
||||
conf.append(create_matcher(utterance))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LanguageIntents:
|
||||
"""Loaded intents for a language."""
|
||||
@ -62,16 +44,11 @@ class DefaultAgent(AbstractConversationAgent):
|
||||
self._lang_intents: dict[str, LanguageIntents] = {}
|
||||
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
|
||||
async def async_initialize(self, config):
|
||||
async def async_initialize(self):
|
||||
"""Initialize the default agent."""
|
||||
if "intent" not in self.hass.config.components:
|
||||
await setup.async_setup_component(self.hass, "intent", {})
|
||||
|
||||
if config and config.get(DOMAIN):
|
||||
_LOGGER.warning(
|
||||
"Custom intent sentences have been moved to config/custom_sentences"
|
||||
)
|
||||
|
||||
self.hass.data.setdefault(DOMAIN, {})
|
||||
|
||||
async def async_process(
|
||||
|
@ -41,10 +41,18 @@ class HumidityHandler(intent.IntentHandler):
|
||||
"""Handle the hass intent."""
|
||||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = intent.async_match_state(
|
||||
hass, slots["name"]["value"], hass.states.async_all(DOMAIN)
|
||||
states = list(
|
||||
intent.async_match_states(
|
||||
hass,
|
||||
name=slots["name"]["value"],
|
||||
states=hass.states.async_all(DOMAIN),
|
||||
)
|
||||
)
|
||||
|
||||
if not states:
|
||||
raise intent.IntentHandleError("No entities matched")
|
||||
|
||||
state = states[0]
|
||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||
|
||||
humidity = slots["humidity"]["value"]
|
||||
@ -85,12 +93,18 @@ class SetModeHandler(intent.IntentHandler):
|
||||
"""Handle the hass intent."""
|
||||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = intent.async_match_state(
|
||||
hass,
|
||||
slots["name"]["value"],
|
||||
hass.states.async_all(DOMAIN),
|
||||
states = list(
|
||||
intent.async_match_states(
|
||||
hass,
|
||||
name=slots["name"]["value"],
|
||||
states=hass.states.async_all(DOMAIN),
|
||||
)
|
||||
)
|
||||
|
||||
if not states:
|
||||
raise intent.IntentHandleError("No entities matched")
|
||||
|
||||
state = states[0]
|
||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||
|
||||
intent.async_test_feature(state, HumidifierEntityFeature.MODES, "modes")
|
||||
|
@ -2,9 +2,19 @@
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import http
|
||||
from homeassistant.components.cover import (
|
||||
DOMAIN as COVER_DOMAIN,
|
||||
SERVICE_CLOSE_COVER,
|
||||
SERVICE_OPEN_COVER,
|
||||
)
|
||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||
from homeassistant.const import SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON
|
||||
from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
SERVICE_TOGGLE,
|
||||
SERVICE_TURN_OFF,
|
||||
SERVICE_TURN_ON,
|
||||
)
|
||||
from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant, State
|
||||
from homeassistant.helpers import config_validation as cv, integration_platform, intent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
@ -21,13 +31,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
intent.async_register(
|
||||
hass,
|
||||
intent.ServiceIntentHandler(
|
||||
OnOffIntentHandler(
|
||||
intent.INTENT_TURN_ON, HA_DOMAIN, SERVICE_TURN_ON, "Turned {} on"
|
||||
),
|
||||
)
|
||||
intent.async_register(
|
||||
hass,
|
||||
intent.ServiceIntentHandler(
|
||||
OnOffIntentHandler(
|
||||
intent.INTENT_TURN_OFF, HA_DOMAIN, SERVICE_TURN_OFF, "Turned {} off"
|
||||
),
|
||||
)
|
||||
@ -41,6 +51,29 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class OnOffIntentHandler(intent.ServiceIntentHandler):
|
||||
"""Intent handler for on/off that handles covers too."""
|
||||
|
||||
async def async_call_service(self, intent_obj: intent.Intent, state: State) -> None:
|
||||
"""Call service on entity with special case for covers."""
|
||||
hass = intent_obj.hass
|
||||
|
||||
if state.domain == COVER_DOMAIN:
|
||||
# on = open
|
||||
# off = close
|
||||
await hass.services.async_call(
|
||||
COVER_DOMAIN,
|
||||
SERVICE_OPEN_COVER
|
||||
if self.service == SERVICE_TURN_ON
|
||||
else SERVICE_CLOSE_COVER,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
else:
|
||||
# Fall back to homeassistant.turn_on/off
|
||||
await super().async_call_service(intent_obj, state)
|
||||
|
||||
|
||||
async def _async_process_intent(hass: HomeAssistant, domain: str, platform):
|
||||
"""Process the intents of an integration."""
|
||||
await platform.async_setup_intents(hass)
|
||||
|
@ -1,12 +1,15 @@
|
||||
"""Intents for the light integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import intent
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import area_registry, config_validation as cv, intent
|
||||
import homeassistant.util.color as color_util
|
||||
|
||||
from . import (
|
||||
@ -18,6 +21,8 @@ from . import (
|
||||
color_supported,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
INTENT_SET = "HassLightSet"
|
||||
|
||||
|
||||
@ -26,30 +31,14 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
|
||||
intent.async_register(hass, SetIntentHandler())
|
||||
|
||||
|
||||
def _test_supports_color(state: State) -> None:
|
||||
"""Test if state supports colors."""
|
||||
supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES)
|
||||
if not color_supported(supported_color_modes):
|
||||
raise intent.IntentHandleError(
|
||||
f"Entity {state.name} does not support changing colors"
|
||||
)
|
||||
|
||||
|
||||
def _test_supports_brightness(state: State) -> None:
|
||||
"""Test if state supports brightness."""
|
||||
supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES)
|
||||
if not brightness_supported(supported_color_modes):
|
||||
raise intent.IntentHandleError(
|
||||
f"Entity {state.name} does not support changing brightness"
|
||||
)
|
||||
|
||||
|
||||
class SetIntentHandler(intent.IntentHandler):
|
||||
"""Handle set color intents."""
|
||||
|
||||
intent_type = INTENT_SET
|
||||
slot_schema = {
|
||||
vol.Required("name"): cv.string,
|
||||
vol.Any("name", "area"): cv.string,
|
||||
vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional("color"): color_util.color_name_to_rgb,
|
||||
vol.Optional("brightness"): vol.All(vol.Coerce(int), vol.Range(0, 100)),
|
||||
}
|
||||
@ -57,36 +46,116 @@ class SetIntentHandler(intent.IntentHandler):
|
||||
async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse:
|
||||
"""Handle the hass intent."""
|
||||
hass = intent_obj.hass
|
||||
service_data: dict[str, Any] = {}
|
||||
speech_parts: list[str] = []
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
state = intent.async_match_state(
|
||||
hass, slots["name"]["value"], hass.states.async_all(DOMAIN)
|
||||
|
||||
name: str | None = slots.get("name", {}).get("value")
|
||||
if name == "all":
|
||||
# Don't match on name if targeting all entities
|
||||
name = None
|
||||
|
||||
# Look up area first to fail early
|
||||
area_name = slots.get("area", {}).get("value")
|
||||
area: area_registry.AreaEntry | None = None
|
||||
if area_name is not None:
|
||||
areas = area_registry.async_get(hass)
|
||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
if area is None:
|
||||
raise intent.IntentHandleError(f"No area named {area_name}")
|
||||
|
||||
# Optional domain/device class filters.
|
||||
# Convert to sets for speed.
|
||||
domains: set[str] | None = None
|
||||
device_classes: set[str] | None = None
|
||||
|
||||
if "domain" in slots:
|
||||
domains = set(slots["domain"]["value"])
|
||||
|
||||
if "device_class" in slots:
|
||||
device_classes = set(slots["device_class"]["value"])
|
||||
|
||||
states = list(
|
||||
intent.async_match_states(
|
||||
hass,
|
||||
name=name,
|
||||
area=area,
|
||||
domains=domains,
|
||||
device_classes=device_classes,
|
||||
)
|
||||
)
|
||||
|
||||
service_data = {ATTR_ENTITY_ID: state.entity_id}
|
||||
speech_parts = []
|
||||
if not states:
|
||||
raise intent.IntentHandleError("No entities matched")
|
||||
|
||||
if "color" in slots:
|
||||
_test_supports_color(state)
|
||||
service_data[ATTR_RGB_COLOR] = slots["color"]["value"]
|
||||
# Use original passed in value of the color because we don't have
|
||||
# human readable names for that internally.
|
||||
speech_parts.append(f"the color {intent_obj.slots['color']['value']}")
|
||||
|
||||
if "brightness" in slots:
|
||||
_test_supports_brightness(state)
|
||||
service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"]
|
||||
speech_parts.append(f"{slots['brightness']['value']}% brightness")
|
||||
|
||||
await hass.services.async_call(
|
||||
DOMAIN, SERVICE_TURN_ON, service_data, context=intent_obj.context
|
||||
response = intent_obj.create_response()
|
||||
needs_brightness = ATTR_BRIGHTNESS_PCT in service_data
|
||||
needs_color = ATTR_RGB_COLOR in service_data
|
||||
|
||||
success_results: list[intent.IntentResponseTarget] = []
|
||||
failed_results: list[intent.IntentResponseTarget] = []
|
||||
service_coros = []
|
||||
|
||||
if area is not None:
|
||||
success_results.append(
|
||||
intent.IntentResponseTarget(
|
||||
type=intent.IntentResponseTargetType.AREA,
|
||||
name=area.name,
|
||||
id=area.id,
|
||||
)
|
||||
)
|
||||
speech_name = area.name
|
||||
else:
|
||||
speech_name = states[0].name
|
||||
|
||||
for state in states:
|
||||
target = intent.IntentResponseTarget(
|
||||
type=intent.IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
)
|
||||
|
||||
# Test brightness/color
|
||||
supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES)
|
||||
if (needs_color and not color_supported(supported_color_modes)) or (
|
||||
needs_brightness and not brightness_supported(supported_color_modes)
|
||||
):
|
||||
failed_results.append(target)
|
||||
continue
|
||||
|
||||
service_coros.append(
|
||||
hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_TURN_ON,
|
||||
{**service_data, ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
)
|
||||
success_results.append(target)
|
||||
|
||||
# Handle service calls in parallel.
|
||||
await asyncio.gather(*service_coros)
|
||||
|
||||
response.async_set_results(
|
||||
success_results=success_results, failed_results=failed_results
|
||||
)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
|
||||
if not speech_parts: # No attributes changed
|
||||
speech = f"Turned on {state.name}"
|
||||
speech = f"Turned on {speech_name}"
|
||||
else:
|
||||
parts = [f"Changed {state.name} to"]
|
||||
parts = [f"Changed {speech_name} to"]
|
||||
for index, part in enumerate(speech_parts):
|
||||
if index == 0:
|
||||
parts.append(f" {part}")
|
||||
@ -97,4 +166,5 @@ class SetIntentHandler(intent.IntentHandler):
|
||||
speech = "".join(parts)
|
||||
|
||||
response.async_set_speech(speech)
|
||||
|
||||
return response
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Collection, Iterable
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@ -11,7 +11,11 @@ from typing import Any, TypeVar
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
ATTR_ENTITY_ID,
|
||||
ATTR_SUPPORTED_FEATURES,
|
||||
)
|
||||
from homeassistant.core import Context, HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -110,51 +114,117 @@ class IntentUnexpectedError(IntentError):
|
||||
"""Unexpected error while handling intent."""
|
||||
|
||||
|
||||
def _is_device_class(
|
||||
state: State,
|
||||
entity: entity_registry.RegistryEntry | None,
|
||||
device_classes: Collection[str],
|
||||
) -> bool:
|
||||
"""Return true if entity device class matches."""
|
||||
# Try entity first
|
||||
if (entity is not None) and (entity.device_class is not None):
|
||||
# Entity device class can be None or blank as "unset"
|
||||
if entity.device_class in device_classes:
|
||||
return True
|
||||
|
||||
# Fall back to state attribute
|
||||
device_class = state.attributes.get(ATTR_DEVICE_CLASS)
|
||||
return (device_class is not None) and (device_class in device_classes)
|
||||
|
||||
|
||||
def _has_name(
|
||||
state: State, entity: entity_registry.RegistryEntry | None, name: str
|
||||
) -> bool:
|
||||
"""Return true if entity name or alias matches."""
|
||||
if name in (state.entity_id, state.name.casefold()):
|
||||
return True
|
||||
|
||||
# Check aliases
|
||||
if (entity is not None) and entity.aliases:
|
||||
for alias in entity.aliases:
|
||||
if name == alias.casefold():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_state(
|
||||
hass: HomeAssistant, name: str, states: Iterable[State] | None = None
|
||||
) -> State:
|
||||
"""Find a state that matches the name."""
|
||||
def async_match_states(
|
||||
hass: HomeAssistant,
|
||||
name: str | None = None,
|
||||
area_name: str | None = None,
|
||||
area: area_registry.AreaEntry | None = None,
|
||||
domains: Collection[str] | None = None,
|
||||
device_classes: Collection[str] | None = None,
|
||||
states: Iterable[State] | None = None,
|
||||
entities: entity_registry.EntityRegistry | None = None,
|
||||
areas: area_registry.AreaRegistry | None = None,
|
||||
) -> Iterable[State]:
|
||||
"""Find states that match the constraints."""
|
||||
if states is None:
|
||||
# All states
|
||||
states = hass.states.async_all()
|
||||
|
||||
name = name.casefold()
|
||||
state: State | None = None
|
||||
registry = entity_registry.async_get(hass)
|
||||
if entities is None:
|
||||
entities = entity_registry.async_get(hass)
|
||||
|
||||
for maybe_state in states:
|
||||
# Check entity id and name
|
||||
if name in (maybe_state.entity_id, maybe_state.name.casefold()):
|
||||
state = maybe_state
|
||||
else:
|
||||
# Check aliases
|
||||
entry = registry.async_get(maybe_state.entity_id)
|
||||
if (entry is not None) and entry.aliases:
|
||||
for alias in entry.aliases:
|
||||
if name == alias.casefold():
|
||||
state = maybe_state
|
||||
break
|
||||
# Gather entities
|
||||
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]] = []
|
||||
for state in states:
|
||||
entity = entities.async_get(state.entity_id)
|
||||
if (entity is not None) and entity.entity_category:
|
||||
# Skip diagnostic entities
|
||||
continue
|
||||
|
||||
if state is not None:
|
||||
break
|
||||
states_and_entities.append((state, entity))
|
||||
|
||||
if state is None:
|
||||
raise IntentHandleError(f"Unable to find an entity called {name}")
|
||||
# Filter by domain and device class
|
||||
if domains:
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if state.domain in domains
|
||||
]
|
||||
|
||||
return state
|
||||
if device_classes:
|
||||
# Check device class in state attribute and in entity entry (if available)
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if _is_device_class(state, entity, device_classes)
|
||||
]
|
||||
|
||||
if (area is None) and (area_name is not None):
|
||||
# Look up area by name
|
||||
if areas is None:
|
||||
areas = area_registry.async_get(hass)
|
||||
|
||||
@callback
|
||||
@bind_hass
|
||||
def async_match_area(
|
||||
hass: HomeAssistant, area_name: str
|
||||
) -> area_registry.AreaEntry | None:
|
||||
"""Find an area that matches the name."""
|
||||
registry = area_registry.async_get(hass)
|
||||
return registry.async_get_area(area_name) or registry.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
# id or name
|
||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
assert area is not None, f"No area named {area_name}"
|
||||
|
||||
if area is not None:
|
||||
# Filter by area
|
||||
states_and_entities = [
|
||||
(state, entity)
|
||||
for state, entity in states_and_entities
|
||||
if (entity is not None) and (entity.area_id == area.id)
|
||||
]
|
||||
|
||||
if name is not None:
|
||||
# Filter by name
|
||||
name = name.casefold()
|
||||
|
||||
for state, entity in states_and_entities:
|
||||
if _has_name(state, entity, name):
|
||||
yield state
|
||||
break
|
||||
else:
|
||||
# Not filtered by name
|
||||
for state, _entity in states_and_entities:
|
||||
yield state
|
||||
|
||||
|
||||
@callback
|
||||
@ -229,102 +299,103 @@ class ServiceIntentHandler(IntentHandler):
|
||||
hass = intent_obj.hass
|
||||
slots = self.async_validate_slots(intent_obj.slots)
|
||||
|
||||
if "area" in slots:
|
||||
# Entities in an area
|
||||
area_name = slots["area"]["value"]
|
||||
area = async_match_area(hass, area_name)
|
||||
assert area is not None
|
||||
assert area.id is not None
|
||||
name: str | None = slots.get("name", {}).get("value")
|
||||
if name == "all":
|
||||
# Don't match on name if targeting all entities
|
||||
name = None
|
||||
|
||||
# Optional domain filter
|
||||
domains: set[str] | None = None
|
||||
if "domain" in slots:
|
||||
domains = set(slots["domain"]["value"])
|
||||
# Look up area first to fail early
|
||||
area_name = slots.get("area", {}).get("value")
|
||||
area: area_registry.AreaEntry | None = None
|
||||
if area_name is not None:
|
||||
areas = area_registry.async_get(hass)
|
||||
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
|
||||
area_name
|
||||
)
|
||||
if area is None:
|
||||
raise IntentHandleError(f"No area named {area_name}")
|
||||
|
||||
# Optional device class filter
|
||||
device_classes: set[str] | None = None
|
||||
if "device_class" in slots:
|
||||
device_classes = set(slots["device_class"]["value"])
|
||||
# Optional domain/device class filters.
|
||||
# Convert to sets for speed.
|
||||
domains: set[str] | None = None
|
||||
device_classes: set[str] | None = None
|
||||
|
||||
success_results = [
|
||||
if "domain" in slots:
|
||||
domains = set(slots["domain"]["value"])
|
||||
|
||||
if "device_class" in slots:
|
||||
device_classes = set(slots["device_class"]["value"])
|
||||
|
||||
states = list(
|
||||
async_match_states(
|
||||
hass,
|
||||
name=name,
|
||||
area=area,
|
||||
domains=domains,
|
||||
device_classes=device_classes,
|
||||
)
|
||||
)
|
||||
|
||||
if not states:
|
||||
raise IntentHandleError("No entities matched")
|
||||
|
||||
response = await self.async_handle_states(intent_obj, states, area)
|
||||
|
||||
return response
|
||||
|
||||
async def async_handle_states(
|
||||
self,
|
||||
intent_obj: Intent,
|
||||
states: list[State],
|
||||
area: area_registry.AreaEntry | None = None,
|
||||
) -> IntentResponse:
|
||||
"""Complete action on matched entity states."""
|
||||
assert states
|
||||
success_results: list[IntentResponseTarget] = []
|
||||
response = intent_obj.create_response()
|
||||
|
||||
if area is not None:
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.AREA, name=area.name, id=area.id
|
||||
)
|
||||
]
|
||||
service_coros = []
|
||||
registry = entity_registry.async_get(hass)
|
||||
for entity_entry in entity_registry.async_entries_for_area(
|
||||
registry, area.id
|
||||
):
|
||||
if entity_entry.entity_category:
|
||||
# Skip diagnostic entities
|
||||
continue
|
||||
|
||||
if domains and (entity_entry.domain not in domains):
|
||||
# Skip entity not in the domain
|
||||
continue
|
||||
|
||||
if device_classes and (entity_entry.device_class not in device_classes):
|
||||
# Skip entity with wrong device class
|
||||
continue
|
||||
|
||||
service_coros.append(
|
||||
hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: entity_entry.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
)
|
||||
|
||||
state = hass.states.get(entity_entry.entity_id)
|
||||
assert state is not None
|
||||
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=entity_entry.entity_id,
|
||||
),
|
||||
)
|
||||
|
||||
if not service_coros:
|
||||
raise IntentHandleError("No entities matched")
|
||||
|
||||
# Handle service calls in parallel.
|
||||
# We will need to handle partial failures here.
|
||||
await asyncio.gather(*service_coros)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(area.name))
|
||||
response.async_set_results(
|
||||
success_results=success_results,
|
||||
)
|
||||
speech_name = area.name
|
||||
else:
|
||||
# Single entity
|
||||
state = async_match_state(hass, slots["name"]["value"])
|
||||
speech_name = states[0].name
|
||||
|
||||
await hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
service_coros = []
|
||||
for state in states:
|
||||
service_coros.append(self.async_call_service(intent_obj, state))
|
||||
success_results.append(
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
),
|
||||
)
|
||||
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech(self.speech.format(state.name))
|
||||
response.async_set_results(
|
||||
success_results=[
|
||||
IntentResponseTarget(
|
||||
type=IntentResponseTargetType.ENTITY,
|
||||
name=state.name,
|
||||
id=state.entity_id,
|
||||
),
|
||||
],
|
||||
)
|
||||
# Handle service calls in parallel.
|
||||
# We will need to handle partial failures here.
|
||||
await asyncio.gather(*service_coros)
|
||||
|
||||
response.async_set_results(
|
||||
success_results=success_results,
|
||||
)
|
||||
response.async_set_speech(self.speech.format(speech_name))
|
||||
|
||||
return response
|
||||
|
||||
async def async_call_service(self, intent_obj: Intent, state: State) -> None:
|
||||
"""Call service on entity."""
|
||||
hass = intent_obj.hass
|
||||
await hass.services.async_call(
|
||||
self.domain,
|
||||
self.service,
|
||||
{ATTR_ENTITY_ID: state.entity_id},
|
||||
context=intent_obj.context,
|
||||
)
|
||||
|
||||
|
||||
class IntentCategory(Enum):
|
||||
"""Category of an intent."""
|
||||
|
@ -2,7 +2,7 @@
|
||||
from homeassistant.components import light
|
||||
from homeassistant.components.light import ATTR_SUPPORTED_COLOR_MODES, ColorMode, intent
|
||||
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
|
||||
from homeassistant.helpers.intent import IntentHandleError, async_handle
|
||||
from homeassistant.helpers.intent import async_handle
|
||||
|
||||
from tests.common import async_mock_service
|
||||
|
||||
@ -40,17 +40,16 @@ async def test_intent_set_color_tests_feature(hass):
|
||||
calls = async_mock_service(hass, light.DOMAIN, light.SERVICE_TURN_ON)
|
||||
await intent.async_setup_intents(hass)
|
||||
|
||||
try:
|
||||
await async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent.INTENT_SET,
|
||||
{"name": {"value": "Hello"}, "color": {"value": "blue"}},
|
||||
)
|
||||
assert False, "handling intent should have raised"
|
||||
except IntentHandleError as err:
|
||||
assert str(err) == "Entity hello does not support changing colors"
|
||||
response = await async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent.INTENT_SET,
|
||||
{"name": {"value": "Hello"}, "color": {"value": "blue"}},
|
||||
)
|
||||
|
||||
# Response should contain one failed target
|
||||
assert len(response.success_results) == 0
|
||||
assert len(response.failed_results) == 1
|
||||
assert len(calls) == 0
|
||||
|
||||
|
||||
|
@ -3,9 +3,15 @@
|
||||
import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.switch import SwitchDeviceClass
|
||||
from homeassistant.const import ATTR_FRIENDLY_NAME
|
||||
from homeassistant.core import State
|
||||
from homeassistant.helpers import config_validation as cv, entity_registry, intent
|
||||
from homeassistant.helpers import (
|
||||
area_registry,
|
||||
config_validation as cv,
|
||||
entity_registry,
|
||||
intent,
|
||||
)
|
||||
|
||||
|
||||
class MockIntentHandler(intent.IntentHandler):
|
||||
@ -16,25 +22,74 @@ class MockIntentHandler(intent.IntentHandler):
|
||||
self.slot_schema = slot_schema
|
||||
|
||||
|
||||
async def test_async_match_state(hass):
|
||||
async def test_async_match_states(hass):
|
||||
"""Test async_match_state helper."""
|
||||
areas = area_registry.async_get(hass)
|
||||
area_kitchen = areas.async_get_or_create("kitchen")
|
||||
area_bedroom = areas.async_get_or_create("bedroom")
|
||||
|
||||
state1 = State(
|
||||
"light.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
|
||||
)
|
||||
state2 = State(
|
||||
"switch.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen switch"}
|
||||
"switch.bedroom", "on", attributes={ATTR_FRIENDLY_NAME: "bedroom switch"}
|
||||
)
|
||||
registry = entity_registry.async_get(hass)
|
||||
registry.async_get_or_create(
|
||||
"switch", "demo", "1234", suggested_object_id="kitchen"
|
||||
|
||||
# Put entities into different areas
|
||||
entities = entity_registry.async_get(hass)
|
||||
entities.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen")
|
||||
entities.async_update_entity(state1.entity_id, area_id=area_kitchen.id)
|
||||
|
||||
entities.async_get_or_create(
|
||||
"switch", "demo", "1234", suggested_object_id="bedroom"
|
||||
)
|
||||
entities.async_update_entity(
|
||||
state2.entity_id,
|
||||
area_id=area_bedroom.id,
|
||||
device_class=SwitchDeviceClass.OUTLET,
|
||||
aliases={"kill switch"},
|
||||
)
|
||||
registry.async_update_entity(state2.entity_id, aliases={"kill switch"})
|
||||
|
||||
state = intent.async_match_state(hass, "kitchen light", [state1, state2])
|
||||
assert state is state1
|
||||
# Match on name
|
||||
assert [state1] == list(
|
||||
intent.async_match_states(hass, name="kitchen light", states=[state1, state2])
|
||||
)
|
||||
|
||||
state = intent.async_match_state(hass, "kill switch", [state1, state2])
|
||||
assert state is state2
|
||||
# Test alias
|
||||
assert [state2] == list(
|
||||
intent.async_match_states(hass, name="kill switch", states=[state1, state2])
|
||||
)
|
||||
|
||||
# Name + area
|
||||
assert [state1] == list(
|
||||
intent.async_match_states(
|
||||
hass, name="kitchen light", area_name="kitchen", states=[state1, state2]
|
||||
)
|
||||
)
|
||||
|
||||
# Wrong area
|
||||
assert not list(
|
||||
intent.async_match_states(
|
||||
hass, name="kitchen light", area_name="bedroom", states=[state1, state2]
|
||||
)
|
||||
)
|
||||
|
||||
# Domain + area
|
||||
assert [state2] == list(
|
||||
intent.async_match_states(
|
||||
hass, domains={"switch"}, area_name="bedroom", states=[state1, state2]
|
||||
)
|
||||
)
|
||||
|
||||
# Device class + area
|
||||
assert [state2] == list(
|
||||
intent.async_match_states(
|
||||
hass,
|
||||
device_classes={SwitchDeviceClass.OUTLET},
|
||||
area_name="bedroom",
|
||||
states=[state1, state2],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_async_validate_slots():
|
||||
|
Loading…
x
Reference in New Issue
Block a user