diff --git a/homeassistant/components/energy/validate.py b/homeassistant/components/energy/validate.py index e48a576f44e..3baae348770 100644 --- a/homeassistant/components/energy/validate.py +++ b/homeassistant/components/energy/validate.py @@ -10,6 +10,7 @@ from homeassistant.components import recorder, sensor from homeassistant.const import ( ATTR_DEVICE_CLASS, ENERGY_KILO_WATT_HOUR, + ENERGY_MEGA_WATT_HOUR, ENERGY_WATT_HOUR, STATE_UNAVAILABLE, STATE_UNKNOWN, @@ -23,7 +24,11 @@ from .const import DOMAIN ENERGY_USAGE_DEVICE_CLASSES = (sensor.SensorDeviceClass.ENERGY,) ENERGY_USAGE_UNITS = { - sensor.SensorDeviceClass.ENERGY: (ENERGY_KILO_WATT_HOUR, ENERGY_WATT_HOUR) + sensor.SensorDeviceClass.ENERGY: ( + ENERGY_KILO_WATT_HOUR, + ENERGY_MEGA_WATT_HOUR, + ENERGY_WATT_HOUR, + ) } ENERGY_PRICE_UNITS = tuple( f"/{unit}" for units in ENERGY_USAGE_UNITS.values() for unit in units diff --git a/tests/components/energy/test_validate.py b/tests/components/energy/test_validate.py index e802688daaf..fe71663d41b 100644 --- a/tests/components/energy/test_validate.py +++ b/tests/components/energy/test_validate.py @@ -4,6 +4,11 @@ from unittest.mock import patch import pytest from homeassistant.components.energy import async_get_manager, validate +from homeassistant.const import ( + ENERGY_KILO_WATT_HOUR, + ENERGY_MEGA_WATT_HOUR, + ENERGY_WATT_HOUR, +) from homeassistant.helpers.json import JSON_DUMP from homeassistant.setup import async_setup_component @@ -60,16 +65,18 @@ async def test_validation_empty_config(hass): @pytest.mark.parametrize( - "state_class, extra", + "state_class, energy_unit, extra", [ - ("total_increasing", {}), - ("total", {}), - ("total", {"last_reset": "abc"}), - ("measurement", {"last_reset": "abc"}), + ("total_increasing", ENERGY_KILO_WATT_HOUR, {}), + ("total_increasing", ENERGY_MEGA_WATT_HOUR, {}), + ("total_increasing", ENERGY_WATT_HOUR, {}), + ("total", ENERGY_KILO_WATT_HOUR, {}), + ("total", ENERGY_KILO_WATT_HOUR, {"last_reset": "abc"}), + ("measurement", ENERGY_KILO_WATT_HOUR, {"last_reset": "abc"}), ], ) async def test_validation( - hass, mock_energy_manager, mock_get_metadata, state_class, extra + hass, mock_energy_manager, mock_get_metadata, state_class, energy_unit, extra ): """Test validating success.""" for key in ("device_cons", "battery_import", "battery_export", "solar_production"): @@ -78,7 +85,7 @@ async def test_validation( "123", { "device_class": "energy", - "unit_of_measurement": "kWh", + "unit_of_measurement": energy_unit, "state_class": state_class, **extra, },