HassTurnOn/Off intents to also handle cover entities (#86206)

* Move entity/area resolution to async_match_states

* Special case for covers in HassTurnOn/Off

* Enable light color/brightness on areas

* Remove async_register from default agent

* Remove CONFIG_SCHEMA from conversation component

* Fix intent tests

* Fix light test

* Move entity/area resolution to async_match_states

* Special case for covers in HassTurnOn/Off

* Enable light color/brightness on areas

* Remove async_register from default agent

* Remove CONFIG_SCHEMA from conversation component

* Fix intent tests

* Fix light test

* Fix humidifier intent handlers

* Remove DATA_CONFIG for conversation

* Copy ServiceIntentHandler code to light

* Add proper errors to humidifier intent handlers
This commit is contained in:
Michael Hansen 2023-01-19 17:15:01 -06:00 committed by GitHub
parent 8f10c22a23
commit 5aca996f22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 429 additions and 228 deletions

View File

@ -16,7 +16,7 @@ from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from .agent import AbstractConversationAgent, ConversationResult from .agent import AbstractConversationAgent, ConversationResult
from .default_agent import DefaultAgent, async_register from .default_agent import DefaultAgent
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -27,7 +27,6 @@ DOMAIN = "conversation"
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
DATA_AGENT = "conversation_agent" DATA_AGENT = "conversation_agent"
DATA_CONFIG = "conversation_config"
SERVICE_PROCESS = "process" SERVICE_PROCESS = "process"
SERVICE_RELOAD = "reload" SERVICE_RELOAD = "reload"
@ -47,22 +46,6 @@ SERVICE_RELOAD_SCHEMA = vol.Schema(
) )
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{
vol.Optional("intents"): vol.Schema(
{cv.string: vol.All(cv.ensure_list, [cv.string])}
)
}
)
},
extra=vol.ALLOW_EXTRA,
)
async_register = bind_hass(async_register)
@core.callback @core.callback
@bind_hass @bind_hass
def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None): def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent | None):
@ -72,7 +55,6 @@ def async_set_agent(hass: core.HomeAssistant, agent: AbstractConversationAgent |
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Register the process service.""" """Register the process service."""
hass.data[DATA_CONFIG] = config
async def handle_process(service: core.ServiceCall) -> None: async def handle_process(service: core.ServiceCall) -> None:
"""Parse text into commands.""" """Parse text into commands."""
@ -228,7 +210,7 @@ async def _get_agent(hass: core.HomeAssistant) -> AbstractConversationAgent:
"""Get the active conversation agent.""" """Get the active conversation agent."""
if (agent := hass.data.get(DATA_AGENT)) is None: if (agent := hass.data.get(DATA_AGENT)) is None:
agent = hass.data[DATA_AGENT] = DefaultAgent(hass) agent = hass.data[DATA_AGENT] = DefaultAgent(hass)
await agent.async_initialize(hass.data.get(DATA_CONFIG)) await agent.async_initialize()
return agent return agent

View File

@ -20,30 +20,12 @@ from homeassistant.helpers import area_registry, entity_registry, intent
from .agent import AbstractConversationAgent, ConversationResult from .agent import AbstractConversationAgent, ConversationResult
from .const import DOMAIN from .const import DOMAIN
from .util import create_matcher
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
REGEX_TYPE = type(re.compile("")) REGEX_TYPE = type(re.compile(""))
@core.callback
def async_register(hass, intent_type, utterances):
"""Register utterances and any custom intents for the default agent.
Registrations don't require conversations to be loaded. They will become
active once the conversation component is loaded.
"""
intents = hass.data.setdefault(DOMAIN, {})
conf = intents.setdefault(intent_type, [])
for utterance in utterances:
if isinstance(utterance, REGEX_TYPE):
conf.append(utterance)
else:
conf.append(create_matcher(utterance))
@dataclass @dataclass
class LanguageIntents: class LanguageIntents:
"""Loaded intents for a language.""" """Loaded intents for a language."""
@ -62,16 +44,11 @@ class DefaultAgent(AbstractConversationAgent):
self._lang_intents: dict[str, LanguageIntents] = {} self._lang_intents: dict[str, LanguageIntents] = {}
self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lang_lock: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
async def async_initialize(self, config): async def async_initialize(self):
"""Initialize the default agent.""" """Initialize the default agent."""
if "intent" not in self.hass.config.components: if "intent" not in self.hass.config.components:
await setup.async_setup_component(self.hass, "intent", {}) await setup.async_setup_component(self.hass, "intent", {})
if config and config.get(DOMAIN):
_LOGGER.warning(
"Custom intent sentences have been moved to config/custom_sentences"
)
self.hass.data.setdefault(DOMAIN, {}) self.hass.data.setdefault(DOMAIN, {})
async def async_process( async def async_process(

View File

@ -41,10 +41,18 @@ class HumidityHandler(intent.IntentHandler):
"""Handle the hass intent.""" """Handle the hass intent."""
hass = intent_obj.hass hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
state = intent.async_match_state( states = list(
hass, slots["name"]["value"], hass.states.async_all(DOMAIN) intent.async_match_states(
hass,
name=slots["name"]["value"],
states=hass.states.async_all(DOMAIN),
)
) )
if not states:
raise intent.IntentHandleError("No entities matched")
state = states[0]
service_data = {ATTR_ENTITY_ID: state.entity_id} service_data = {ATTR_ENTITY_ID: state.entity_id}
humidity = slots["humidity"]["value"] humidity = slots["humidity"]["value"]
@ -85,12 +93,18 @@ class SetModeHandler(intent.IntentHandler):
"""Handle the hass intent.""" """Handle the hass intent."""
hass = intent_obj.hass hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
state = intent.async_match_state( states = list(
intent.async_match_states(
hass, hass,
slots["name"]["value"], name=slots["name"]["value"],
hass.states.async_all(DOMAIN), states=hass.states.async_all(DOMAIN),
)
) )
if not states:
raise intent.IntentHandleError("No entities matched")
state = states[0]
service_data = {ATTR_ENTITY_ID: state.entity_id} service_data = {ATTR_ENTITY_ID: state.entity_id}
intent.async_test_feature(state, HumidifierEntityFeature.MODES, "modes") intent.async_test_feature(state, HumidifierEntityFeature.MODES, "modes")

View File

@ -2,9 +2,19 @@
import voluptuous as vol import voluptuous as vol
from homeassistant.components import http from homeassistant.components import http
from homeassistant.components.cover import (
DOMAIN as COVER_DOMAIN,
SERVICE_CLOSE_COVER,
SERVICE_OPEN_COVER,
)
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON from homeassistant.const import (
from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant ATTR_ENTITY_ID,
SERVICE_TOGGLE,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
)
from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant, State
from homeassistant.helpers import config_validation as cv, integration_platform, intent from homeassistant.helpers import config_validation as cv, integration_platform, intent
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -21,13 +31,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
intent.async_register( intent.async_register(
hass, hass,
intent.ServiceIntentHandler( OnOffIntentHandler(
intent.INTENT_TURN_ON, HA_DOMAIN, SERVICE_TURN_ON, "Turned {} on" intent.INTENT_TURN_ON, HA_DOMAIN, SERVICE_TURN_ON, "Turned {} on"
), ),
) )
intent.async_register( intent.async_register(
hass, hass,
intent.ServiceIntentHandler( OnOffIntentHandler(
intent.INTENT_TURN_OFF, HA_DOMAIN, SERVICE_TURN_OFF, "Turned {} off" intent.INTENT_TURN_OFF, HA_DOMAIN, SERVICE_TURN_OFF, "Turned {} off"
), ),
) )
@ -41,6 +51,29 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
class OnOffIntentHandler(intent.ServiceIntentHandler):
"""Intent handler for on/off that handles covers too."""
async def async_call_service(self, intent_obj: intent.Intent, state: State) -> None:
"""Call service on entity with special case for covers."""
hass = intent_obj.hass
if state.domain == COVER_DOMAIN:
# on = open
# off = close
await hass.services.async_call(
COVER_DOMAIN,
SERVICE_OPEN_COVER
if self.service == SERVICE_TURN_ON
else SERVICE_CLOSE_COVER,
{ATTR_ENTITY_ID: state.entity_id},
context=intent_obj.context,
)
else:
# Fall back to homeassistant.turn_on/off
await super().async_call_service(intent_obj, state)
async def _async_process_intent(hass: HomeAssistant, domain: str, platform): async def _async_process_intent(hass: HomeAssistant, domain: str, platform):
"""Process the intents of an integration.""" """Process the intents of an integration."""
await platform.async_setup_intents(hass) await platform.async_setup_intents(hass)

View File

@ -1,12 +1,15 @@
"""Intents for the light integration.""" """Intents for the light integration."""
from __future__ import annotations from __future__ import annotations
import asyncio
import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent from homeassistant.helpers import area_registry, config_validation as cv, intent
import homeassistant.helpers.config_validation as cv
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from . import ( from . import (
@ -18,6 +21,8 @@ from . import (
color_supported, color_supported,
) )
_LOGGER = logging.getLogger(__name__)
INTENT_SET = "HassLightSet" INTENT_SET = "HassLightSet"
@ -26,30 +31,14 @@ async def async_setup_intents(hass: HomeAssistant) -> None:
intent.async_register(hass, SetIntentHandler()) intent.async_register(hass, SetIntentHandler())
def _test_supports_color(state: State) -> None:
"""Test if state supports colors."""
supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES)
if not color_supported(supported_color_modes):
raise intent.IntentHandleError(
f"Entity {state.name} does not support changing colors"
)
def _test_supports_brightness(state: State) -> None:
"""Test if state supports brightness."""
supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES)
if not brightness_supported(supported_color_modes):
raise intent.IntentHandleError(
f"Entity {state.name} does not support changing brightness"
)
class SetIntentHandler(intent.IntentHandler): class SetIntentHandler(intent.IntentHandler):
"""Handle set color intents.""" """Handle set color intents."""
intent_type = INTENT_SET intent_type = INTENT_SET
slot_schema = { slot_schema = {
vol.Required("name"): cv.string, vol.Any("name", "area"): cv.string,
vol.Optional("domain"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("device_class"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("color"): color_util.color_name_to_rgb, vol.Optional("color"): color_util.color_name_to_rgb,
vol.Optional("brightness"): vol.All(vol.Coerce(int), vol.Range(0, 100)), vol.Optional("brightness"): vol.All(vol.Coerce(int), vol.Range(0, 100)),
} }
@ -57,36 +46,116 @@ class SetIntentHandler(intent.IntentHandler):
async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse: async def async_handle(self, intent_obj: intent.Intent) -> intent.IntentResponse:
"""Handle the hass intent.""" """Handle the hass intent."""
hass = intent_obj.hass hass = intent_obj.hass
service_data: dict[str, Any] = {}
speech_parts: list[str] = []
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
state = intent.async_match_state(
hass, slots["name"]["value"], hass.states.async_all(DOMAIN) name: str | None = slots.get("name", {}).get("value")
if name == "all":
# Don't match on name if targeting all entities
name = None
# Look up area first to fail early
area_name = slots.get("area", {}).get("value")
area: area_registry.AreaEntry | None = None
if area_name is not None:
areas = area_registry.async_get(hass)
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
area_name
)
if area is None:
raise intent.IntentHandleError(f"No area named {area_name}")
# Optional domain/device class filters.
# Convert to sets for speed.
domains: set[str] | None = None
device_classes: set[str] | None = None
if "domain" in slots:
domains = set(slots["domain"]["value"])
if "device_class" in slots:
device_classes = set(slots["device_class"]["value"])
states = list(
intent.async_match_states(
hass,
name=name,
area=area,
domains=domains,
device_classes=device_classes,
)
) )
service_data = {ATTR_ENTITY_ID: state.entity_id} if not states:
speech_parts = [] raise intent.IntentHandleError("No entities matched")
if "color" in slots: if "color" in slots:
_test_supports_color(state)
service_data[ATTR_RGB_COLOR] = slots["color"]["value"] service_data[ATTR_RGB_COLOR] = slots["color"]["value"]
# Use original passed in value of the color because we don't have # Use original passed in value of the color because we don't have
# human readable names for that internally. # human readable names for that internally.
speech_parts.append(f"the color {intent_obj.slots['color']['value']}") speech_parts.append(f"the color {intent_obj.slots['color']['value']}")
if "brightness" in slots: if "brightness" in slots:
_test_supports_brightness(state)
service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"] service_data[ATTR_BRIGHTNESS_PCT] = slots["brightness"]["value"]
speech_parts.append(f"{slots['brightness']['value']}% brightness") speech_parts.append(f"{slots['brightness']['value']}% brightness")
await hass.services.async_call( response = intent_obj.create_response()
DOMAIN, SERVICE_TURN_ON, service_data, context=intent_obj.context needs_brightness = ATTR_BRIGHTNESS_PCT in service_data
needs_color = ATTR_RGB_COLOR in service_data
success_results: list[intent.IntentResponseTarget] = []
failed_results: list[intent.IntentResponseTarget] = []
service_coros = []
if area is not None:
success_results.append(
intent.IntentResponseTarget(
type=intent.IntentResponseTargetType.AREA,
name=area.name,
id=area.id,
)
)
speech_name = area.name
else:
speech_name = states[0].name
for state in states:
target = intent.IntentResponseTarget(
type=intent.IntentResponseTargetType.ENTITY,
name=state.name,
id=state.entity_id,
) )
response = intent_obj.create_response() # Test brightness/color
supported_color_modes = state.attributes.get(ATTR_SUPPORTED_COLOR_MODES)
if (needs_color and not color_supported(supported_color_modes)) or (
needs_brightness and not brightness_supported(supported_color_modes)
):
failed_results.append(target)
continue
service_coros.append(
hass.services.async_call(
DOMAIN,
SERVICE_TURN_ON,
{**service_data, ATTR_ENTITY_ID: state.entity_id},
context=intent_obj.context,
)
)
success_results.append(target)
# Handle service calls in parallel.
await asyncio.gather(*service_coros)
response.async_set_results(
success_results=success_results, failed_results=failed_results
)
if not speech_parts: # No attributes changed if not speech_parts: # No attributes changed
speech = f"Turned on {state.name}" speech = f"Turned on {speech_name}"
else: else:
parts = [f"Changed {state.name} to"] parts = [f"Changed {speech_name} to"]
for index, part in enumerate(speech_parts): for index, part in enumerate(speech_parts):
if index == 0: if index == 0:
parts.append(f" {part}") parts.append(f" {part}")
@ -97,4 +166,5 @@ class SetIntentHandler(intent.IntentHandler):
speech = "".join(parts) speech = "".join(parts)
response.async_set_speech(speech) response.async_set_speech(speech)
return response return response

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Iterable from collections.abc import Collection, Iterable
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
@ -11,7 +11,11 @@ from typing import Any, TypeVar
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_ENTITY_ID,
ATTR_SUPPORTED_FEATURES,
)
from homeassistant.core import Context, HomeAssistant, State, callback from homeassistant.core import Context, HomeAssistant, State, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -110,51 +114,117 @@ class IntentUnexpectedError(IntentError):
"""Unexpected error while handling intent.""" """Unexpected error while handling intent."""
def _is_device_class(
state: State,
entity: entity_registry.RegistryEntry | None,
device_classes: Collection[str],
) -> bool:
"""Return true if entity device class matches."""
# Try entity first
if (entity is not None) and (entity.device_class is not None):
# Entity device class can be None or blank as "unset"
if entity.device_class in device_classes:
return True
# Fall back to state attribute
device_class = state.attributes.get(ATTR_DEVICE_CLASS)
return (device_class is not None) and (device_class in device_classes)
def _has_name(
state: State, entity: entity_registry.RegistryEntry | None, name: str
) -> bool:
"""Return true if entity name or alias matches."""
if name in (state.entity_id, state.name.casefold()):
return True
# Check aliases
if (entity is not None) and entity.aliases:
for alias in entity.aliases:
if name == alias.casefold():
return True
return False
@callback @callback
@bind_hass @bind_hass
def async_match_state( def async_match_states(
hass: HomeAssistant, name: str, states: Iterable[State] | None = None hass: HomeAssistant,
) -> State: name: str | None = None,
"""Find a state that matches the name.""" area_name: str | None = None,
area: area_registry.AreaEntry | None = None,
domains: Collection[str] | None = None,
device_classes: Collection[str] | None = None,
states: Iterable[State] | None = None,
entities: entity_registry.EntityRegistry | None = None,
areas: area_registry.AreaRegistry | None = None,
) -> Iterable[State]:
"""Find states that match the constraints."""
if states is None: if states is None:
# All states
states = hass.states.async_all() states = hass.states.async_all()
name = name.casefold() if entities is None:
state: State | None = None entities = entity_registry.async_get(hass)
registry = entity_registry.async_get(hass)
for maybe_state in states: # Gather entities
# Check entity id and name states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]] = []
if name in (maybe_state.entity_id, maybe_state.name.casefold()): for state in states:
state = maybe_state entity = entities.async_get(state.entity_id)
else: if (entity is not None) and entity.entity_category:
# Check aliases # Skip diagnostic entities
entry = registry.async_get(maybe_state.entity_id) continue
if (entry is not None) and entry.aliases:
for alias in entry.aliases:
if name == alias.casefold():
state = maybe_state
break
if state is not None: states_and_entities.append((state, entity))
break
if state is None: # Filter by domain and device class
raise IntentHandleError(f"Unable to find an entity called {name}") if domains:
states_and_entities = [
(state, entity)
for state, entity in states_and_entities
if state.domain in domains
]
return state if device_classes:
# Check device class in state attribute and in entity entry (if available)
states_and_entities = [
(state, entity)
for state, entity in states_and_entities
if _is_device_class(state, entity, device_classes)
]
if (area is None) and (area_name is not None):
# Look up area by name
if areas is None:
areas = area_registry.async_get(hass)
@callback # id or name
@bind_hass area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
def async_match_area(
hass: HomeAssistant, area_name: str
) -> area_registry.AreaEntry | None:
"""Find an area that matches the name."""
registry = area_registry.async_get(hass)
return registry.async_get_area(area_name) or registry.async_get_area_by_name(
area_name area_name
) )
assert area is not None, f"No area named {area_name}"
if area is not None:
# Filter by area
states_and_entities = [
(state, entity)
for state, entity in states_and_entities
if (entity is not None) and (entity.area_id == area.id)
]
if name is not None:
# Filter by name
name = name.casefold()
for state, entity in states_and_entities:
if _has_name(state, entity, name):
yield state
break
else:
# Not filtered by name
for state, _entity in states_and_entities:
yield state
@callback @callback
@ -229,81 +299,96 @@ class ServiceIntentHandler(IntentHandler):
hass = intent_obj.hass hass = intent_obj.hass
slots = self.async_validate_slots(intent_obj.slots) slots = self.async_validate_slots(intent_obj.slots)
if "area" in slots: name: str | None = slots.get("name", {}).get("value")
# Entities in an area if name == "all":
area_name = slots["area"]["value"] # Don't match on name if targeting all entities
area = async_match_area(hass, area_name) name = None
assert area is not None
assert area.id is not None
# Optional domain filter # Look up area first to fail early
area_name = slots.get("area", {}).get("value")
area: area_registry.AreaEntry | None = None
if area_name is not None:
areas = area_registry.async_get(hass)
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
area_name
)
if area is None:
raise IntentHandleError(f"No area named {area_name}")
# Optional domain/device class filters.
# Convert to sets for speed.
domains: set[str] | None = None domains: set[str] | None = None
device_classes: set[str] | None = None
if "domain" in slots: if "domain" in slots:
domains = set(slots["domain"]["value"]) domains = set(slots["domain"]["value"])
# Optional device class filter
device_classes: set[str] | None = None
if "device_class" in slots: if "device_class" in slots:
device_classes = set(slots["device_class"]["value"]) device_classes = set(slots["device_class"]["value"])
success_results = [ states = list(
async_match_states(
hass,
name=name,
area=area,
domains=domains,
device_classes=device_classes,
)
)
if not states:
raise IntentHandleError("No entities matched")
response = await self.async_handle_states(intent_obj, states, area)
return response
async def async_handle_states(
self,
intent_obj: Intent,
states: list[State],
area: area_registry.AreaEntry | None = None,
) -> IntentResponse:
"""Complete action on matched entity states."""
assert states
success_results: list[IntentResponseTarget] = []
response = intent_obj.create_response()
if area is not None:
success_results.append(
IntentResponseTarget( IntentResponseTarget(
type=IntentResponseTargetType.AREA, name=area.name, id=area.id type=IntentResponseTargetType.AREA, name=area.name, id=area.id
) )
] )
speech_name = area.name
else:
speech_name = states[0].name
service_coros = [] service_coros = []
registry = entity_registry.async_get(hass) for state in states:
for entity_entry in entity_registry.async_entries_for_area( service_coros.append(self.async_call_service(intent_obj, state))
registry, area.id
):
if entity_entry.entity_category:
# Skip diagnostic entities
continue
if domains and (entity_entry.domain not in domains):
# Skip entity not in the domain
continue
if device_classes and (entity_entry.device_class not in device_classes):
# Skip entity with wrong device class
continue
service_coros.append(
hass.services.async_call(
self.domain,
self.service,
{ATTR_ENTITY_ID: entity_entry.entity_id},
context=intent_obj.context,
)
)
state = hass.states.get(entity_entry.entity_id)
assert state is not None
success_results.append( success_results.append(
IntentResponseTarget( IntentResponseTarget(
type=IntentResponseTargetType.ENTITY, type=IntentResponseTargetType.ENTITY,
name=state.name, name=state.name,
id=entity_entry.entity_id, id=state.entity_id,
), ),
) )
if not service_coros:
raise IntentHandleError("No entities matched")
# Handle service calls in parallel. # Handle service calls in parallel.
# We will need to handle partial failures here. # We will need to handle partial failures here.
await asyncio.gather(*service_coros) await asyncio.gather(*service_coros)
response = intent_obj.create_response()
response.async_set_speech(self.speech.format(area.name))
response.async_set_results( response.async_set_results(
success_results=success_results, success_results=success_results,
) )
else: response.async_set_speech(self.speech.format(speech_name))
# Single entity
state = async_match_state(hass, slots["name"]["value"])
return response
async def async_call_service(self, intent_obj: Intent, state: State) -> None:
"""Call service on entity."""
hass = intent_obj.hass
await hass.services.async_call( await hass.services.async_call(
self.domain, self.domain,
self.service, self.service,
@ -311,20 +396,6 @@ class ServiceIntentHandler(IntentHandler):
context=intent_obj.context, context=intent_obj.context,
) )
response = intent_obj.create_response()
response.async_set_speech(self.speech.format(state.name))
response.async_set_results(
success_results=[
IntentResponseTarget(
type=IntentResponseTargetType.ENTITY,
name=state.name,
id=state.entity_id,
),
],
)
return response
class IntentCategory(Enum): class IntentCategory(Enum):
"""Category of an intent.""" """Category of an intent."""

View File

@ -2,7 +2,7 @@
from homeassistant.components import light from homeassistant.components import light
from homeassistant.components.light import ATTR_SUPPORTED_COLOR_MODES, ColorMode, intent from homeassistant.components.light import ATTR_SUPPORTED_COLOR_MODES, ColorMode, intent
from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON from homeassistant.const import ATTR_ENTITY_ID, SERVICE_TURN_ON
from homeassistant.helpers.intent import IntentHandleError, async_handle from homeassistant.helpers.intent import async_handle
from tests.common import async_mock_service from tests.common import async_mock_service
@ -40,17 +40,16 @@ async def test_intent_set_color_tests_feature(hass):
calls = async_mock_service(hass, light.DOMAIN, light.SERVICE_TURN_ON) calls = async_mock_service(hass, light.DOMAIN, light.SERVICE_TURN_ON)
await intent.async_setup_intents(hass) await intent.async_setup_intents(hass)
try: response = await async_handle(
await async_handle(
hass, hass,
"test", "test",
intent.INTENT_SET, intent.INTENT_SET,
{"name": {"value": "Hello"}, "color": {"value": "blue"}}, {"name": {"value": "Hello"}, "color": {"value": "blue"}},
) )
assert False, "handling intent should have raised"
except IntentHandleError as err:
assert str(err) == "Entity hello does not support changing colors"
# Response should contain one failed target
assert len(response.success_results) == 0
assert len(response.failed_results) == 1
assert len(calls) == 0 assert len(calls) == 0

View File

@ -3,9 +3,15 @@
import pytest import pytest
import voluptuous as vol import voluptuous as vol
from homeassistant.components.switch import SwitchDeviceClass
from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import State from homeassistant.core import State
from homeassistant.helpers import config_validation as cv, entity_registry, intent from homeassistant.helpers import (
area_registry,
config_validation as cv,
entity_registry,
intent,
)
class MockIntentHandler(intent.IntentHandler): class MockIntentHandler(intent.IntentHandler):
@ -16,25 +22,74 @@ class MockIntentHandler(intent.IntentHandler):
self.slot_schema = slot_schema self.slot_schema = slot_schema
async def test_async_match_state(hass): async def test_async_match_states(hass):
"""Test async_match_state helper.""" """Test async_match_state helper."""
areas = area_registry.async_get(hass)
area_kitchen = areas.async_get_or_create("kitchen")
area_bedroom = areas.async_get_or_create("bedroom")
state1 = State( state1 = State(
"light.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"} "light.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
) )
state2 = State( state2 = State(
"switch.kitchen", "on", attributes={ATTR_FRIENDLY_NAME: "kitchen switch"} "switch.bedroom", "on", attributes={ATTR_FRIENDLY_NAME: "bedroom switch"}
) )
registry = entity_registry.async_get(hass)
registry.async_get_or_create( # Put entities into different areas
"switch", "demo", "1234", suggested_object_id="kitchen" entities = entity_registry.async_get(hass)
entities.async_get_or_create("light", "demo", "1234", suggested_object_id="kitchen")
entities.async_update_entity(state1.entity_id, area_id=area_kitchen.id)
entities.async_get_or_create(
"switch", "demo", "1234", suggested_object_id="bedroom"
)
entities.async_update_entity(
state2.entity_id,
area_id=area_bedroom.id,
device_class=SwitchDeviceClass.OUTLET,
aliases={"kill switch"},
) )
registry.async_update_entity(state2.entity_id, aliases={"kill switch"})
state = intent.async_match_state(hass, "kitchen light", [state1, state2]) # Match on name
assert state is state1 assert [state1] == list(
intent.async_match_states(hass, name="kitchen light", states=[state1, state2])
)
state = intent.async_match_state(hass, "kill switch", [state1, state2]) # Test alias
assert state is state2 assert [state2] == list(
intent.async_match_states(hass, name="kill switch", states=[state1, state2])
)
# Name + area
assert [state1] == list(
intent.async_match_states(
hass, name="kitchen light", area_name="kitchen", states=[state1, state2]
)
)
# Wrong area
assert not list(
intent.async_match_states(
hass, name="kitchen light", area_name="bedroom", states=[state1, state2]
)
)
# Domain + area
assert [state2] == list(
intent.async_match_states(
hass, domains={"switch"}, area_name="bedroom", states=[state1, state2]
)
)
# Device class + area
assert [state2] == list(
intent.async_match_states(
hass,
device_classes={SwitchDeviceClass.OUTLET},
area_name="bedroom",
states=[state1, state2],
)
)
def test_async_validate_slots(): def test_async_validate_slots():