Enhance automation integration to use new features in script helper (#37479)

This commit is contained in:
Phil Bruckner 2020-07-05 09:25:15 -05:00 committed by GitHub
parent c3b5bf7437
commit f7c4900d5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 161 additions and 74 deletions

View File

@ -9,10 +9,13 @@ import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_ID,
ATTR_NAME, ATTR_NAME,
CONF_ALIAS,
CONF_DEVICE_ID, CONF_DEVICE_ID,
CONF_ENTITY_ID, CONF_ENTITY_ID,
CONF_ID, CONF_ID,
CONF_MODE,
CONF_PLATFORM, CONF_PLATFORM,
CONF_QUEUE_SIZE,
CONF_ZONE, CONF_ZONE,
EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STARTED,
SERVICE_RELOAD, SERVICE_RELOAD,
@ -23,11 +26,12 @@ from homeassistant.const import (
) )
from homeassistant.core import Context, CoreState, HomeAssistant, callback from homeassistant.core import Context, CoreState, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, extract_domain_configs, script from homeassistant.helpers import condition, extract_domain_configs
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.script import SCRIPT_BASE_SCHEMA, Script, validate_queue_size
from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -41,7 +45,6 @@ ENTITY_ID_FORMAT = DOMAIN + ".{}"
GROUP_NAME_ALL_AUTOMATIONS = "all automations" GROUP_NAME_ALL_AUTOMATIONS = "all automations"
CONF_ALIAS = "alias"
CONF_DESCRIPTION = "description" CONF_DESCRIPTION = "description"
CONF_HIDE_ENTITY = "hide_entity" CONF_HIDE_ENTITY = "hide_entity"
@ -96,7 +99,7 @@ _CONDITION_SCHEMA = vol.All(cv.ensure_list, [cv.CONDITION_SCHEMA])
PLATFORM_SCHEMA = vol.All( PLATFORM_SCHEMA = vol.All(
cv.deprecated(CONF_HIDE_ENTITY, invalidation_version="0.110"), cv.deprecated(CONF_HIDE_ENTITY, invalidation_version="0.110"),
vol.Schema( SCRIPT_BASE_SCHEMA.extend(
{ {
# str on purpose # str on purpose
CONF_ID: str, CONF_ID: str,
@ -109,6 +112,7 @@ PLATFORM_SCHEMA = vol.All(
vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA,
} }
), ),
validate_queue_size,
) )
@ -389,7 +393,7 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
try: try:
await self.action_script.async_run(variables, trigger_context) await self.action_script.async_run(variables, trigger_context)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
pass _LOGGER.exception("While executing automation %s", self.entity_id)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from Home Assistant.""" """Remove listeners when removing automation from Home Assistant."""
@ -498,8 +502,13 @@ async def _async_process_config(hass, config, component):
initial_state = config_block.get(CONF_INITIAL_STATE) initial_state = config_block.get(CONF_INITIAL_STATE)
action_script = script.Script( action_script = Script(
hass, config_block.get(CONF_ACTION, {}), name, logger=_LOGGER hass,
config_block[CONF_ACTION],
name,
script_mode=config_block[CONF_MODE],
queue_size=config_block.get(CONF_QUEUE_SIZE, 0),
logger=_LOGGER,
) )
if CONF_CONDITION in config_block: if CONF_CONDITION in config_block:

View File

@ -1,6 +1,7 @@
"""Config validation helper for the automation integration.""" """Config validation helper for the automation integration."""
import asyncio import asyncio
import importlib import importlib
import logging
import voluptuous as vol import voluptuous as vol
@ -8,13 +9,20 @@ from homeassistant.components.device_automation.exceptions import (
InvalidDeviceAutomationConfig, InvalidDeviceAutomationConfig,
) )
from homeassistant.config import async_log_exception, config_without_domain from homeassistant.config import async_log_exception, config_without_domain
from homeassistant.const import CONF_PLATFORM from homeassistant.const import CONF_ALIAS, CONF_ID, CONF_MODE, CONF_PLATFORM
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import condition, config_per_platform, script from homeassistant.helpers import condition, config_per_platform
from homeassistant.helpers.script import (
SCRIPT_MODE_LEGACY,
async_validate_action_config,
warn_deprecated_legacy,
)
from homeassistant.loader import IntegrationNotFound from homeassistant.loader import IntegrationNotFound
from . import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORM_SCHEMA from . import CONF_ACTION, CONF_CONDITION, CONF_TRIGGER, DOMAIN, PLATFORM_SCHEMA
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
# mypy: no-check-untyped-defs, no-warn-return-any # mypy: no-check-untyped-defs, no-warn-return-any
@ -44,10 +52,7 @@ async def async_validate_config_item(hass, config, full_config=None):
) )
config[CONF_ACTION] = await asyncio.gather( config[CONF_ACTION] = await asyncio.gather(
*[ *[async_validate_action_config(hass, action) for action in config[CONF_ACTION]]
script.async_validate_action_config(hass, action)
for action in config[CONF_ACTION]
]
) )
return config return config
@ -69,24 +74,54 @@ async def _try_async_validate_config_item(hass, config, full_config=None):
return config return config
def _deprecated_legacy_mode(config):
legacy_names = []
legacy_unnamed_found = False
for cfg in config[DOMAIN]:
mode = cfg.get(CONF_MODE)
if mode is None:
cfg[CONF_MODE] = SCRIPT_MODE_LEGACY
name = cfg.get(CONF_ID) or cfg.get(CONF_ALIAS)
if name:
legacy_names.append(name)
else:
legacy_unnamed_found = True
if legacy_names or legacy_unnamed_found:
msgs = []
if legacy_unnamed_found:
msgs.append("unnamed automations")
if legacy_names:
if len(legacy_names) == 1:
base_msg = "this automation"
else:
base_msg = "these automations"
msgs.append(f"{base_msg}: {', '.join(legacy_names)}")
warn_deprecated_legacy(_LOGGER, " and ".join(msgs))
return config
async def async_validate_config(hass, config): async def async_validate_config(hass, config):
"""Validate config.""" """Validate config."""
validated_automations = await asyncio.gather( automations = list(
filter(
lambda x: x is not None,
await asyncio.gather(
*( *(
_try_async_validate_config_item(hass, p_config, config) _try_async_validate_config_item(hass, p_config, config)
for _, p_config in config_per_platform(config, DOMAIN) for _, p_config in config_per_platform(config, DOMAIN)
) )
),
)
) )
automations = [
validated_automation
for validated_automation in validated_automations
if validated_automation is not None
]
# Create a copy of the configuration with all config for current # Create a copy of the configuration with all config for current
# component removed and add validated config back in. # component removed and add validated config back in.
config = config_without_domain(config, DOMAIN) config = config_without_domain(config, DOMAIN)
config[DOMAIN] = automations config[DOMAIN] = automations
_deprecated_legacy_mode(config)
return config return config

View File

@ -11,6 +11,7 @@ from homeassistant.const import (
CONF_ALIAS, CONF_ALIAS,
CONF_ICON, CONF_ICON,
CONF_MODE, CONF_MODE,
CONF_QUEUE_SIZE,
SERVICE_RELOAD, SERVICE_RELOAD,
SERVICE_TOGGLE, SERVICE_TOGGLE,
SERVICE_TURN_OFF, SERVICE_TURN_OFF,
@ -23,11 +24,11 @@ from homeassistant.helpers.config_validation import make_entity_service_schema
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.script import ( from homeassistant.helpers.script import (
DEFAULT_QUEUE_MAX, SCRIPT_BASE_SCHEMA,
SCRIPT_MODE_CHOICES,
SCRIPT_MODE_LEGACY, SCRIPT_MODE_LEGACY,
SCRIPT_MODE_QUEUE,
Script, Script,
validate_queue_size,
warn_deprecated_legacy,
) )
from homeassistant.helpers.service import async_set_service_schema from homeassistant.helpers.service import async_set_service_schema
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -44,7 +45,6 @@ CONF_DESCRIPTION = "description"
CONF_EXAMPLE = "example" CONF_EXAMPLE = "example"
CONF_FIELDS = "fields" CONF_FIELDS = "fields"
CONF_SEQUENCE = "sequence" CONF_SEQUENCE = "sequence"
CONF_QUEUE_MAX = "queue_size"
ENTITY_ID_FORMAT = DOMAIN + ".{}" ENTITY_ID_FORMAT = DOMAIN + ".{}"
@ -59,33 +59,12 @@ def _deprecated_legacy_mode(config):
legacy_scripts.append(object_id) legacy_scripts.append(object_id)
cfg[CONF_MODE] = SCRIPT_MODE_LEGACY cfg[CONF_MODE] = SCRIPT_MODE_LEGACY
if legacy_scripts: if legacy_scripts:
_LOGGER.warning( warn_deprecated_legacy(_LOGGER, f"script(s): {', '.join(legacy_scripts)}")
"Script behavior has changed. "
"To continue using previous behavior, which is now deprecated, "
"add '%s: %s' to script(s): %s.",
CONF_MODE,
SCRIPT_MODE_LEGACY,
", ".join(legacy_scripts),
)
return config return config
def _queue_max(config): SCRIPT_ENTRY_SCHEMA = vol.All(
for object_id, cfg in config.items(): SCRIPT_BASE_SCHEMA.extend(
mode = cfg[CONF_MODE]
queue_max = cfg.get(CONF_QUEUE_MAX)
if mode == SCRIPT_MODE_QUEUE:
if queue_max is None:
cfg[CONF_QUEUE_MAX] = DEFAULT_QUEUE_MAX
elif queue_max is not None:
raise vol.Invalid(
f"{CONF_QUEUE_MAX} not valid with {mode} {CONF_MODE} "
f"for script '{object_id}'"
)
return config
SCRIPT_ENTRY_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_ALIAS): cv.string, vol.Optional(CONF_ALIAS): cv.string,
vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_ICON): cv.icon,
@ -97,17 +76,15 @@ SCRIPT_ENTRY_SCHEMA = vol.Schema(
vol.Optional(CONF_EXAMPLE): cv.string, vol.Optional(CONF_EXAMPLE): cv.string,
} }
}, },
vol.Optional(CONF_MODE): vol.In(SCRIPT_MODE_CHOICES),
vol.Optional(CONF_QUEUE_MAX): vol.All(vol.Coerce(int), vol.Range(min=2)),
} }
),
validate_queue_size,
) )
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.All( DOMAIN: vol.All(
cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA), cv.schema_with_slug_keys(SCRIPT_ENTRY_SCHEMA), _deprecated_legacy_mode
_deprecated_legacy_mode,
_queue_max,
) )
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
@ -271,7 +248,7 @@ async def _async_process_config(hass, config, component):
cfg.get(CONF_ICON), cfg.get(CONF_ICON),
cfg[CONF_SEQUENCE], cfg[CONF_SEQUENCE],
cfg[CONF_MODE], cfg[CONF_MODE],
cfg.get(CONF_QUEUE_MAX, 0), cfg.get(CONF_QUEUE_SIZE, 0),
) )
) )
@ -303,7 +280,7 @@ class ScriptEntity(ToggleEntity):
icon = None icon = None
def __init__(self, hass, object_id, name, icon, sequence, mode, queue_max): def __init__(self, hass, object_id, name, icon, sequence, mode, queue_size):
"""Initialize the script.""" """Initialize the script."""
self.object_id = object_id self.object_id = object_id
self.icon = icon self.icon = icon
@ -314,7 +291,7 @@ class ScriptEntity(ToggleEntity):
name, name,
self.async_change_listener, self.async_change_listener,
mode, mode,
queue_max, queue_size,
logging.getLogger(f"{__name__}.{object_id}"), logging.getLogger(f"{__name__}.{object_id}"),
) )
self._changed = asyncio.Event() self._changed = asyncio.Event()

View File

@ -134,6 +134,7 @@ CONF_PREFIX = "prefix"
CONF_PROFILE_NAME = "profile_name" CONF_PROFILE_NAME = "profile_name"
CONF_PROTOCOL = "protocol" CONF_PROTOCOL = "protocol"
CONF_PROXY_SSL = "proxy_ssl" CONF_PROXY_SSL = "proxy_ssl"
CONF_QUEUE_SIZE = "queue_size"
CONF_QUOTE = "quote" CONF_QUOTE = "quote"
CONF_RADIUS = "radius" CONF_RADIUS = "radius"
CONF_RECIPIENT = "recipient" CONF_RECIPIENT = "recipient"

View File

@ -24,6 +24,8 @@ from homeassistant.const import (
CONF_EVENT, CONF_EVENT,
CONF_EVENT_DATA, CONF_EVENT_DATA,
CONF_EVENT_DATA_TEMPLATE, CONF_EVENT_DATA_TEMPLATE,
CONF_MODE,
CONF_QUEUE_SIZE,
CONF_SCENE, CONF_SCENE,
CONF_TIMEOUT, CONF_TIMEOUT,
CONF_WAIT_TEMPLATE, CONF_WAIT_TEMPLATE,
@ -72,11 +74,42 @@ SCRIPT_MODE_CHOICES = [
] ]
DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY DEFAULT_SCRIPT_MODE = SCRIPT_MODE_LEGACY
DEFAULT_QUEUE_MAX = 10 DEFAULT_QUEUE_SIZE = 10
_LOG_EXCEPTION = logging.ERROR + 1 _LOG_EXCEPTION = logging.ERROR + 1
_TIMEOUT_MSG = "Timeout reached, abort script." _TIMEOUT_MSG = "Timeout reached, abort script."
SCRIPT_BASE_SCHEMA = vol.Schema(
{
vol.Optional(CONF_MODE): vol.In(SCRIPT_MODE_CHOICES),
vol.Optional(CONF_QUEUE_SIZE): vol.All(vol.Coerce(int), vol.Range(min=2)),
}
)
def warn_deprecated_legacy(logger, msg):
"""Warn about deprecated legacy mode."""
logger.warning(
"Script behavior has changed. "
"To continue using previous behavior, which is now deprecated, "
"add '%s: %s' to %s.",
CONF_MODE,
SCRIPT_MODE_LEGACY,
msg,
)
def validate_queue_size(config):
"""Validate queue_size option."""
mode = config.get(CONF_MODE, DEFAULT_SCRIPT_MODE)
queue_size = config.get(CONF_QUEUE_SIZE)
if mode == SCRIPT_MODE_QUEUE:
if queue_size is None:
config[CONF_QUEUE_SIZE] = DEFAULT_QUEUE_SIZE
elif queue_size is not None:
raise vol.Invalid(f"{CONF_QUEUE_SIZE} not valid with {mode} {CONF_MODE}")
return config
async def async_validate_action_config( async def async_validate_action_config(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
@ -673,7 +706,7 @@ class Script:
name: Optional[str] = None, name: Optional[str] = None,
change_listener: Optional[Callable[..., Any]] = None, change_listener: Optional[Callable[..., Any]] = None,
script_mode: str = DEFAULT_SCRIPT_MODE, script_mode: str = DEFAULT_SCRIPT_MODE,
queue_max: int = DEFAULT_QUEUE_MAX, queue_size: int = DEFAULT_QUEUE_SIZE,
logger: Optional[logging.Logger] = None, logger: Optional[logging.Logger] = None,
log_exceptions: bool = True, log_exceptions: bool = True,
) -> None: ) -> None:
@ -702,7 +735,7 @@ class Script:
self._runs: List[_ScriptRunBase] = [] self._runs: List[_ScriptRunBase] = []
if script_mode == SCRIPT_MODE_QUEUE: if script_mode == SCRIPT_MODE_QUEUE:
self._queue_max = queue_max self._queue_size = queue_size
self._queue_len = 0 self._queue_len = 0
self._queue_lck = asyncio.Lock() self._queue_lck = asyncio.Lock()
self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {} self._config_cache: Dict[Set[Tuple], Callable[..., bool]] = {}
@ -806,7 +839,7 @@ class Script:
self._queue_len, self._queue_len,
"s" if self._queue_len > 1 else "", "s" if self._queue_len > 1 else "",
) )
if self._queue_len >= self._queue_max: if self._queue_len >= self._queue_size:
raise QueueFull raise QueueFull
if self.is_legacy: if self.is_legacy:

View File

@ -1070,3 +1070,35 @@ async def test_logbook_humanify_automation_triggered_event(hass):
assert event2["domain"] == "automation" assert event2["domain"] == "automation"
assert event2["message"] == "has been triggered" assert event2["message"] == "has been triggered"
assert event2["entity_id"] == "automation.bye" assert event2["entity_id"] == "automation.bye"
async def test_invalid_config(hass):
"""Test invalid config."""
with assert_setup_component(0, automation.DOMAIN):
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"mode": "parallel",
"queue_size": 5,
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
}
},
)
async def test_config_legacy(hass, caplog):
"""Test config defaulting to legacy mode."""
assert await async_setup_component(
hass,
automation.DOMAIN,
{
automation.DOMAIN: {
"trigger": {"platform": "event", "event_type": "test_event"},
"action": [],
}
},
)
assert "To continue using previous behavior, which is now deprecated" in caplog.text