From c72ac87c73dc0ea8ce2c1f279c640872c12dbaac Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 8 Oct 2019 05:10:21 +0200 Subject: [PATCH] Fix device condition scaffold (#27300) --- .../integration/device_condition.py | 35 ++++++++++++++----- .../tests/test_device_condition.py | 7 ++-- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/script/scaffold/templates/device_condition/integration/device_condition.py b/script/scaffold/templates/device_condition/integration/device_condition.py index d19fa8817a0..e9c7e55e23a 100644 --- a/script/scaffold/templates/device_condition/integration/device_condition.py +++ b/script/scaffold/templates/device_condition/integration/device_condition.py @@ -4,24 +4,28 @@ import voluptuous as vol from homeassistant.const import ( ATTR_ENTITY_ID, + CONF_CONDITION, CONF_DOMAIN, CONF_TYPE, - CONF_PLATFORM, CONF_DEVICE_ID, CONF_ENTITY_ID, + STATE_OFF, STATE_ON, ) from homeassistant.core import HomeAssistant -from homeassistant.helpers import condition, entity_registry +from homeassistant.helpers import condition, config_validation as cv, entity_registry from homeassistant.helpers.typing import ConfigType, TemplateVarsType from homeassistant.helpers.config_validation import DEVICE_CONDITION_BASE_SCHEMA from . import DOMAIN # TODO specify your supported condition types. -CONDITION_TYPES = {"is_on"} +CONDITION_TYPES = {"is_on", "is_off"} CONDITION_SCHEMA = DEVICE_CONDITION_BASE_SCHEMA.extend( - {vol.Required(CONF_TYPE): vol.In(CONDITION_TYPES)} + { + vol.Required(CONF_ENTITY_ID): cv.entity_id, + vol.Required(CONF_TYPE): vol.In(CONDITION_TYPES), + } ) @@ -39,13 +43,22 @@ async def async_get_conditions(hass: HomeAssistant, device_id: str) -> List[str] # TODO add your own conditions. conditions.append( { - CONF_PLATFORM: "device", + CONF_CONDITION: "device", CONF_DEVICE_ID: device_id, CONF_DOMAIN: DOMAIN, CONF_ENTITY_ID: entry.entity_id, CONF_TYPE: "is_on", } ) + conditions.append( + { + CONF_CONDITION: "device", + CONF_DEVICE_ID: device_id, + CONF_DOMAIN: DOMAIN, + CONF_ENTITY_ID: entry.entity_id, + CONF_TYPE: "is_off", + } + ) return conditions @@ -56,9 +69,13 @@ def async_condition_from_config( """Create a function to test a device condition.""" if config_validation: config = CONDITION_SCHEMA(config) + if config[CONF_TYPE] == "is_on": + state = STATE_ON + else: + state = STATE_OFF - def test_is_on(hass: HomeAssistant, variables: TemplateVarsType) -> bool: - """Test if an entity is on.""" - return condition.state(hass, config[ATTR_ENTITY_ID], STATE_ON) + def test_is_state(hass: HomeAssistant, variables: TemplateVarsType) -> bool: + """Test if an entity is a certain state.""" + return condition.state(hass, config[ATTR_ENTITY_ID], state) - return test_is_on + return test_is_state diff --git a/script/scaffold/templates/device_condition/tests/test_device_condition.py b/script/scaffold/templates/device_condition/tests/test_device_condition.py index d9cef083510..1ae4df5f1b7 100644 --- a/script/scaffold/templates/device_condition/tests/test_device_condition.py +++ b/script/scaffold/templates/device_condition/tests/test_device_condition.py @@ -1,7 +1,7 @@ """The tests for NEW_NAME device conditions.""" import pytest -from homeassistant.components.switch import DOMAIN +from homeassistant.components.NEW_DOMAIN import DOMAIN from homeassistant.const import STATE_ON, STATE_OFF from homeassistant.setup import async_setup_component import homeassistant.components.automation as automation @@ -9,6 +9,7 @@ from homeassistant.helpers import device_registry from tests.common import ( MockConfigEntry, + assert_lists_same, async_mock_service, mock_device_registry, mock_registry, @@ -35,7 +36,7 @@ def calls(hass): async def test_get_conditions(hass, device_reg, entity_reg): - """Test we get the expected conditions from a switch.""" + """Test we get the expected conditions from a NEW_DOMAIN.""" config_entry = MockConfigEntry(domain="test", data={}) config_entry.add_to_hass(hass) device_entry = device_reg.async_get_or_create( @@ -60,7 +61,7 @@ async def test_get_conditions(hass, device_reg, entity_reg): }, ] conditions = await async_get_device_automations(hass, "condition", device_entry.id) - assert conditions == expected_conditions + assert_lists_same(conditions, expected_conditions) async def test_if_state(hass, calls):