From ba814af701e339596ee056e66a36224083dbe70b Mon Sep 17 00:00:00 2001 From: Diogo Gomes Date: Mon, 21 Mar 2022 07:21:26 +0000 Subject: [PATCH] Adopt SelectEntity in Utility Meter (#55690) Co-authored-by: Erik Montnemery --- .../components/utility_meter/__init__.py | 114 ++-------- .../components/utility_meter/const.py | 4 + .../components/utility_meter/select.py | 204 ++++++++++++++++++ .../components/utility_meter/services.yaml | 4 +- tests/components/utility_meter/test_init.py | 107 ++++++++- tests/components/utility_meter/test_sensor.py | 14 +- 6 files changed, 329 insertions(+), 118 deletions(-) create mode 100644 homeassistant/components/utility_meter/select.py diff --git a/homeassistant/components/utility_meter/__init__.py b/homeassistant/components/utility_meter/__init__.py index 525b4f3b43c..6cd6cc46933 100644 --- a/homeassistant/components/utility_meter/__init__.py +++ b/homeassistant/components/utility_meter/__init__.py @@ -5,18 +5,16 @@ import logging from croniter import croniter import voluptuous as vol +from homeassistant.components.select import DOMAIN as SELECT_DOMAIN from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.const import CONF_NAME from homeassistant.core import HomeAssistant from homeassistant.helpers import discovery import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity_component import EntityComponent -from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType from .const import ( - ATTR_TARIFF, CONF_CRON_PATTERN, CONF_METER, CONF_METER_DELTA_VALUES, @@ -27,22 +25,15 @@ from .const import ( CONF_TARIFF, CONF_TARIFF_ENTITY, CONF_TARIFFS, + DATA_LEGACY_COMPONENT, DATA_TARIFF_SENSORS, DATA_UTILITY, DOMAIN, METER_TYPES, - SERVICE_RESET, - SERVICE_SELECT_NEXT_TARIFF, - SERVICE_SELECT_TARIFF, - SIGNAL_RESET_METER, ) _LOGGER = logging.getLogger(__name__) -TARIFF_ICON = "mdi:clock-outline" - -ATTR_TARIFFS = "tariffs" - DEFAULT_OFFSET = timedelta(hours=0) @@ -105,9 +96,9 @@ CONFIG_SCHEMA = vol.Schema( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up an Utility Meter.""" - component = EntityComponent(_LOGGER, DOMAIN, hass) + hass.data[DATA_LEGACY_COMPONENT] = EntityComponent(_LOGGER, DOMAIN, hass) + hass.data[DATA_UTILITY] = {} - register_services = False for meter, conf in config[DOMAIN].items(): _LOGGER.debug("Setup %s.%s", DOMAIN, meter) @@ -129,11 +120,18 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: ) else: # create tariff selection - await component.async_add_entities( - [TariffSelect(meter, list(conf[CONF_TARIFFS]))] + hass.async_create_task( + discovery.async_load_platform( + hass, + SELECT_DOMAIN, + DOMAIN, + {CONF_METER: meter, CONF_TARIFFS: conf[CONF_TARIFFS]}, + config, + ) ) + hass.data[DATA_UTILITY][meter][CONF_TARIFF_ENTITY] = "{}.{}".format( - DOMAIN, meter + SELECT_DOMAIN, meter ) # add one meter for each tariff @@ -151,89 +149,5 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass, SENSOR_DOMAIN, DOMAIN, tariff_confs, config ) ) - register_services = True - - if register_services: - component.async_register_entity_service(SERVICE_RESET, {}, "async_reset_meters") - - component.async_register_entity_service( - SERVICE_SELECT_TARIFF, - {vol.Required(ATTR_TARIFF): cv.string}, - "async_select_tariff", - ) - - component.async_register_entity_service( - SERVICE_SELECT_NEXT_TARIFF, {}, "async_next_tariff" - ) return True - - -class TariffSelect(RestoreEntity): - """Representation of a Tariff selector.""" - - def __init__(self, name, tariffs): - """Initialize a tariff selector.""" - self._name = name - self._current_tariff = None - self._tariffs = tariffs - self._icon = TARIFF_ICON - - async def async_added_to_hass(self): - """Run when entity about to be added.""" - await super().async_added_to_hass() - - state = await self.async_get_last_state() - if not state or state.state not in self._tariffs: - self._current_tariff = self._tariffs[0] - else: - self._current_tariff = state.state - - @property - def should_poll(self): - """If entity should be polled.""" - return False - - @property - def name(self): - """Return the name of the select input.""" - return self._name - - @property - def icon(self): - """Return the icon to be used for this entity.""" - return self._icon - - @property - def state(self): - """Return the state of the component.""" - return self._current_tariff - - @property - def extra_state_attributes(self): - """Return the state attributes.""" - return {ATTR_TARIFFS: self._tariffs} - - async def async_reset_meters(self): - """Reset all sensors of this meter.""" - _LOGGER.debug("reset meter %s", self.entity_id) - async_dispatcher_send(self.hass, SIGNAL_RESET_METER, self.entity_id) - - async def async_select_tariff(self, tariff): - """Select new option.""" - if tariff not in self._tariffs: - _LOGGER.warning( - "Invalid tariff: %s (possible tariffs: %s)", - tariff, - ", ".join(self._tariffs), - ) - return - self._current_tariff = tariff - self.async_write_ha_state() - - async def async_next_tariff(self): - """Offset current index.""" - current_index = self._tariffs.index(self._current_tariff) - new_index = (current_index + 1) % len(self._tariffs) - self._current_tariff = self._tariffs[new_index] - self.async_write_ha_state() diff --git a/homeassistant/components/utility_meter/const.py b/homeassistant/components/utility_meter/const.py index 097496e231d..2bac649aace 100644 --- a/homeassistant/components/utility_meter/const.py +++ b/homeassistant/components/utility_meter/const.py @@ -1,6 +1,8 @@ """Constants for the utility meter component.""" DOMAIN = "utility_meter" +TARIFF_ICON = "mdi:clock-outline" + QUARTER_HOURLY = "quarter-hourly" HOURLY = "hourly" DAILY = "daily" @@ -23,6 +25,7 @@ METER_TYPES = [ DATA_UTILITY = "utility_meter_data" DATA_TARIFF_SENSORS = "utility_meter_sensors" +DATA_LEGACY_COMPONENT = "utility_meter_legacy_component" CONF_METER = "meter" CONF_SOURCE_SENSOR = "source" @@ -37,6 +40,7 @@ CONF_TARIFF_ENTITY = "tariff_entity" CONF_CRON_PATTERN = "cron" ATTR_TARIFF = "tariff" +ATTR_TARIFFS = "tariffs" ATTR_VALUE = "value" ATTR_CRON_PATTERN = "cron pattern" diff --git a/homeassistant/components/utility_meter/select.py b/homeassistant/components/utility_meter/select.py new file mode 100644 index 00000000000..b523d72aba4 --- /dev/null +++ b/homeassistant/components/utility_meter/select.py @@ -0,0 +1,204 @@ +"""Support for tariff selection.""" +from __future__ import annotations + +import logging + +import voluptuous as vol + +from homeassistant.components.select import SelectEntity +from homeassistant.components.select.const import ( + ATTR_OPTION, + ATTR_OPTIONS, + DOMAIN as SELECT_DOMAIN, + SERVICE_SELECT_OPTION, +) +from homeassistant.const import ATTR_ENTITY_ID, ATTR_FRIENDLY_NAME, STATE_UNAVAILABLE +from homeassistant.core import Event, callback, split_entity_id +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.dispatcher import async_dispatcher_send +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.event import async_track_state_change_event +from homeassistant.helpers.restore_state import RestoreEntity + +from .const import ( + ATTR_TARIFF, + ATTR_TARIFFS, + CONF_METER, + CONF_TARIFFS, + DATA_LEGACY_COMPONENT, + DOMAIN, + SERVICE_RESET, + SERVICE_SELECT_NEXT_TARIFF, + SERVICE_SELECT_TARIFF, + SIGNAL_RESET_METER, + TARIFF_ICON, +) + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_platform(hass, conf, async_add_entities, discovery_info=None): + """Set up the utility meter select.""" + legacy_component = hass.data[DATA_LEGACY_COMPONENT] + async_add_entities( + [ + TariffSelect( + discovery_info[CONF_METER], + discovery_info[CONF_TARIFFS], + legacy_component.async_add_entities, + ) + ] + ) + + async def async_reset_meters(service_call): + """Reset all sensors of a meter.""" + entity_id = service_call.data["entity_id"] + + domain = split_entity_id(entity_id)[0] + if domain == DOMAIN: + for entity in legacy_component.entities: + if entity_id == entity.entity_id: + _LOGGER.debug( + "forward reset meter from %s to %s", + entity_id, + entity.tracked_entity_id, + ) + entity_id = entity.tracked_entity_id + + _LOGGER.debug("reset meter %s", entity_id) + async_dispatcher_send(hass, SIGNAL_RESET_METER, entity_id) + + hass.services.async_register( + DOMAIN, + SERVICE_RESET, + async_reset_meters, + vol.Schema({ATTR_ENTITY_ID: cv.entity_id}), + ) + + legacy_component.async_register_entity_service( + SERVICE_SELECT_TARIFF, + {vol.Required(ATTR_TARIFF): cv.string}, + "async_select_tariff", + ) + + legacy_component.async_register_entity_service( + SERVICE_SELECT_NEXT_TARIFF, {}, "async_next_tariff" + ) + + +class TariffSelect(SelectEntity, RestoreEntity): + """Representation of a Tariff selector.""" + + def __init__(self, name, tariffs, add_legacy_entities): + """Initialize a tariff selector.""" + self._attr_name = name + self._current_tariff = None + self._tariffs = tariffs + self._attr_icon = TARIFF_ICON + self._attr_should_poll = False + self._add_legacy_entities = add_legacy_entities + + @property + def options(self): + """Return the available tariffs.""" + return self._tariffs + + @property + def current_option(self): + """Return current tariff.""" + return self._current_tariff + + async def async_added_to_hass(self): + """Run when entity about to be added.""" + await super().async_added_to_hass() + + await self._add_legacy_entities([LegacyTariffSelect(self.entity_id)]) + + state = await self.async_get_last_state() + if not state or state.state not in self._tariffs: + self._current_tariff = self._tariffs[0] + else: + self._current_tariff = state.state + + async def async_select_option(self, option: str) -> None: + """Select new tariff (option).""" + self._current_tariff = option + self.async_write_ha_state() + + +class LegacyTariffSelect(Entity): + """Backwards compatibility for deprecated utility_meter select entity.""" + + def __init__(self, tracked_entity_id): + """Initialize the entity.""" + self._attr_icon = TARIFF_ICON + # Set name to influence enity_id + self._attr_name = split_entity_id(tracked_entity_id)[1] + self.tracked_entity_id = tracked_entity_id + + @callback + def async_state_changed_listener(self, event: Event | None = None) -> None: + """Handle child updates.""" + if ( + state := self.hass.states.get(self.tracked_entity_id) + ) is None or state.state == STATE_UNAVAILABLE: + self._attr_available = False + return + + self._attr_available = True + + self._attr_name = state.attributes.get(ATTR_FRIENDLY_NAME) + self._attr_state = state.state + self._attr_extra_state_attributes = { + ATTR_TARIFFS: state.attributes.get(ATTR_OPTIONS) + } + + async def async_added_to_hass(self) -> None: + """Register callbacks.""" + + @callback + def _async_state_changed_listener(event: Event | None = None) -> None: + """Handle child updates.""" + self.async_state_changed_listener(event) + self.async_write_ha_state() + + self.async_on_remove( + async_track_state_change_event( + self.hass, [self.tracked_entity_id], _async_state_changed_listener + ) + ) + + # Call once on adding + _async_state_changed_listener() + + async def async_select_tariff(self, tariff): + """Select new option.""" + _LOGGER.warning( + "The 'utility_meter.select_tariff' service has been deprecated and will " + "be removed in HA Core 2022.7. Please use 'select.select_option' instead", + ) + await self.hass.services.async_call( + SELECT_DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_ENTITY_ID: self.tracked_entity_id, ATTR_OPTION: tariff}, + blocking=True, + context=self._context, + ) + + async def async_next_tariff(self): + """Offset current index.""" + _LOGGER.warning( + "The 'utility_meter.next_tariff' service has been deprecated and will " + "be removed in HA Core 2022.7. Please use 'select.select_option' instead", + ) + if ( + not self.available + or (state := self.hass.states.get(self.tracked_entity_id)) is None + ): + return + tariffs = state.attributes.get(ATTR_OPTIONS) + current_tariff = state.state + current_index = tariffs.index(current_tariff) + new_index = (current_index + 1) % len(tariffs) + + await self.async_select_tariff(tariffs[new_index]) diff --git a/homeassistant/components/utility_meter/services.yaml b/homeassistant/components/utility_meter/services.yaml index c3f95d22175..800e001f6ff 100644 --- a/homeassistant/components/utility_meter/services.yaml +++ b/homeassistant/components/utility_meter/services.yaml @@ -2,10 +2,10 @@ reset: name: Reset - description: Resets the counter of a utility meter. + description: Resets all counters of an utility meter. target: entity: - domain: utility_meter + domain: select next_tariff: name: Next Tariff diff --git a/tests/components/utility_meter/test_init.py b/tests/components/utility_meter/test_init.py index 3297c696ca1..8b600865d44 100644 --- a/tests/components/utility_meter/test_init.py +++ b/tests/components/utility_meter/test_init.py @@ -2,12 +2,16 @@ from datetime import timedelta from unittest.mock import patch +from homeassistant.components.select.const import ( + DOMAIN as SELECT_DOMAIN, + SERVICE_SELECT_OPTION, +) from homeassistant.components.utility_meter.const import ( - ATTR_TARIFF, DOMAIN, SERVICE_RESET, SERVICE_SELECT_NEXT_TARIFF, SERVICE_SELECT_TARIFF, + SIGNAL_RESET_METER, ) import homeassistant.components.utility_meter.sensor as um_sensor from homeassistant.const import ( @@ -19,6 +23,7 @@ from homeassistant.const import ( Platform, ) from homeassistant.core import State +from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -39,7 +44,7 @@ async def test_restore_state(hass): hass, [ State( - "utility_meter.energy_bill", + "select.energy_bill", "midpeak", ), ], @@ -50,7 +55,7 @@ async def test_restore_state(hass): await hass.async_block_till_done() # restore from cache - state = hass.states.get("utility_meter.energy_bill") + state = hass.states.get("select.energy_bill") assert state.state == "midpeak" @@ -98,7 +103,7 @@ async def test_services(hass): state = hass.states.get("sensor.energy_bill_offpeak") assert state.state == "0" - # Next tariff + # Next tariff - only supported on legacy entity data = {ATTR_ENTITY_ID: "utility_meter.energy_bill"} await hass.services.async_call(DOMAIN, SERVICE_SELECT_NEXT_TARIFF, data) await hass.async_block_till_done() @@ -120,15 +125,15 @@ async def test_services(hass): assert state.state == "1" # Change tariff - data = {ATTR_ENTITY_ID: "utility_meter.energy_bill", ATTR_TARIFF: "wrong_tariff"} - await hass.services.async_call(DOMAIN, SERVICE_SELECT_TARIFF, data) + data = {ATTR_ENTITY_ID: "select.energy_bill", "option": "wrong_tariff"} + await hass.services.async_call(SELECT_DOMAIN, SERVICE_SELECT_OPTION, data) await hass.async_block_till_done() # Inexisting tariff, ignoring - assert hass.states.get("utility_meter.energy_bill").state != "wrong_tariff" + assert hass.states.get("select.energy_bill").state != "wrong_tariff" - data = {ATTR_ENTITY_ID: "utility_meter.energy_bill", ATTR_TARIFF: "peak"} - await hass.services.async_call(DOMAIN, SERVICE_SELECT_TARIFF, data) + data = {ATTR_ENTITY_ID: "select.energy_bill", "option": "peak"} + await hass.services.async_call(SELECT_DOMAIN, SERVICE_SELECT_OPTION, data) await hass.async_block_till_done() now += timedelta(seconds=10) @@ -148,7 +153,7 @@ async def test_services(hass): assert state.state == "1" # Reset meters - data = {ATTR_ENTITY_ID: "utility_meter.energy_bill"} + data = {ATTR_ENTITY_ID: "select.energy_bill"} await hass.services.async_call(DOMAIN, SERVICE_RESET, data) await hass.async_block_till_done() @@ -240,3 +245,85 @@ async def test_bad_cron(hass, legacy_patchable_time): async def test_setup_missing_discovery(hass): """Test setup with configuration missing discovery_info.""" assert not await um_sensor.async_setup_platform(hass, {CONF_PLATFORM: DOMAIN}, None) + + +async def test_legacy_support(hass): + """Test legacy entity support.""" + config = { + "utility_meter": { + "energy_bill": { + "source": "sensor.energy", + "cycle": "hourly", + "tariffs": ["peak", "offpeak"], + }, + } + } + + assert await async_setup_component(hass, DOMAIN, config) + assert await async_setup_component(hass, Platform.SENSOR, config) + await hass.async_block_till_done() + + select_state = hass.states.get("select.energy_bill") + legacy_state = hass.states.get("utility_meter.energy_bill") + + assert select_state.state == legacy_state.state == "peak" + select_attributes = select_state.attributes + legacy_attributes = legacy_state.attributes + assert select_attributes.keys() == { + "friendly_name", + "icon", + "options", + } + assert legacy_attributes.keys() == {"friendly_name", "icon", "tariffs"} + assert select_attributes["friendly_name"] == legacy_attributes["friendly_name"] + assert select_attributes["icon"] == legacy_attributes["icon"] + assert select_attributes["options"] == legacy_attributes["tariffs"] + + # Change tariff on the select + data = {ATTR_ENTITY_ID: "select.energy_bill", "option": "offpeak"} + await hass.services.async_call(SELECT_DOMAIN, SERVICE_SELECT_OPTION, data) + await hass.async_block_till_done() + + select_state = hass.states.get("select.energy_bill") + legacy_state = hass.states.get("utility_meter.energy_bill") + assert select_state.state == legacy_state.state == "offpeak" + + # Change tariff on the legacy entity + data = {ATTR_ENTITY_ID: "utility_meter.energy_bill", "tariff": "offpeak"} + await hass.services.async_call(DOMAIN, SERVICE_SELECT_TARIFF, data) + await hass.async_block_till_done() + + select_state = hass.states.get("select.energy_bill") + legacy_state = hass.states.get("utility_meter.energy_bill") + assert select_state.state == legacy_state.state == "offpeak" + + # Cycle tariffs on the select - not supported + data = {ATTR_ENTITY_ID: "select.energy_bill"} + await hass.services.async_call(DOMAIN, SERVICE_SELECT_NEXT_TARIFF, data) + await hass.async_block_till_done() + + select_state = hass.states.get("select.energy_bill") + legacy_state = hass.states.get("utility_meter.energy_bill") + assert select_state.state == legacy_state.state == "offpeak" + + # Cycle tariffs on the legacy entity + data = {ATTR_ENTITY_ID: "utility_meter.energy_bill"} + await hass.services.async_call(DOMAIN, SERVICE_SELECT_NEXT_TARIFF, data) + await hass.async_block_till_done() + + select_state = hass.states.get("select.energy_bill") + legacy_state = hass.states.get("utility_meter.energy_bill") + assert select_state.state == legacy_state.state == "peak" + + # Reset the legacy entity + reset_calls = [] + + def async_reset_meter(entity_id): + reset_calls.append(entity_id) + + async_dispatcher_connect(hass, SIGNAL_RESET_METER, async_reset_meter) + + data = {ATTR_ENTITY_ID: "utility_meter.energy_bill"} + await hass.services.async_call(DOMAIN, SERVICE_RESET, data) + await hass.async_block_till_done() + assert reset_calls == ["select.energy_bill"] diff --git a/tests/components/utility_meter/test_sensor.py b/tests/components/utility_meter/test_sensor.py index fbaf795f9e2..df8e1c5e6a1 100644 --- a/tests/components/utility_meter/test_sensor.py +++ b/tests/components/utility_meter/test_sensor.py @@ -3,20 +3,22 @@ from contextlib import contextmanager from datetime import timedelta from unittest.mock import patch +from homeassistant.components.select.const import ( + DOMAIN as SELECT_DOMAIN, + SERVICE_SELECT_OPTION, +) from homeassistant.components.sensor import ( ATTR_STATE_CLASS, SensorDeviceClass, SensorStateClass, ) from homeassistant.components.utility_meter.const import ( - ATTR_TARIFF, ATTR_VALUE, DAILY, DOMAIN, HOURLY, QUARTER_HOURLY, SERVICE_CALIBRATE_METER, - SERVICE_SELECT_TARIFF, ) from homeassistant.components.utility_meter.sensor import ( ATTR_LAST_RESET, @@ -117,9 +119,9 @@ async def test_state(hass): assert state.attributes.get("status") == PAUSED await hass.services.async_call( - DOMAIN, - SERVICE_SELECT_TARIFF, - {ATTR_ENTITY_ID: "utility_meter.energy_bill", ATTR_TARIFF: "offpeak"}, + SELECT_DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_ENTITY_ID: "select.energy_bill", "option": "offpeak"}, blocking=True, ) @@ -343,7 +345,7 @@ async def test_restore_state(hass): hass.bus.async_fire(EVENT_HOMEASSISTANT_START) await hass.async_block_till_done() - state = hass.states.get("utility_meter.energy_bill") + state = hass.states.get("select.energy_bill") assert state.state == "onpeak" state = hass.states.get("sensor.energy_bill_onpeak")