diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index bb884d6392f..b41e549093d 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -39,6 +39,7 @@ from homeassistant.components.light import ( from homeassistant.components.sensor import ( CONF_STATE_CLASS, DEVICE_CLASS_UNITS, + STATE_CLASS_UNITS, SensorDeviceClass, SensorStateClass, ) @@ -640,6 +641,13 @@ def validate_sensor_platform_config( ): errors[CONF_UNIT_OF_MEASUREMENT] = "invalid_uom" + if ( + (state_class := config.get(CONF_STATE_CLASS)) is not None + and state_class in STATE_CLASS_UNITS + and config.get(CONF_UNIT_OF_MEASUREMENT) not in STATE_CLASS_UNITS[state_class] + ): + errors[CONF_UNIT_OF_MEASUREMENT] = "invalid_uom_for_state_class" + return errors @@ -676,11 +684,19 @@ class PlatformField: @callback def unit_of_measurement_selector(user_data: dict[str, Any | None]) -> Selector: """Return a context based unit of measurement selector.""" + + if (state_class := user_data.get(CONF_STATE_CLASS)) in STATE_CLASS_UNITS: + return SelectSelector( + SelectSelectorConfig( + options=[str(uom) for uom in STATE_CLASS_UNITS[state_class]], + sort=True, + custom_value=True, + ) + ) + if ( - user_data is None - or (device_class := user_data.get(CONF_DEVICE_CLASS)) is None - or device_class not in DEVICE_CLASS_UNITS - ): + device_class := user_data.get(CONF_DEVICE_CLASS) + ) is None or device_class not in DEVICE_CLASS_UNITS: return TEXT_SELECTOR return SelectSelector( SelectSelectorConfig( diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index b27ef68368a..46d475fcee8 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -14,6 +14,7 @@ from homeassistant.components.sensor import ( DEVICE_CLASS_UNITS, DEVICE_CLASSES_SCHEMA, ENTITY_ID_FORMAT, + STATE_CLASS_UNITS, STATE_CLASSES_SCHEMA, RestoreSensor, SensorDeviceClass, @@ -117,6 +118,17 @@ def validate_sensor_state_and_device_class_config(config: ConfigType) -> ConfigT f"got `{CONF_DEVICE_CLASS}` '{device_class}'" ) + if ( + (state_class := config.get(CONF_STATE_CLASS)) is not None + and state_class in STATE_CLASS_UNITS + and (unit_of_measurement := config.get(CONF_UNIT_OF_MEASUREMENT)) + not in STATE_CLASS_UNITS[state_class] + ): + raise vol.Invalid( + f"The unit of measurement '{unit_of_measurement}' is not valid " + f"together with state class '{state_class}'" + ) + if (device_class := config.get(CONF_DEVICE_CLASS)) is None or ( unit_of_measurement := config.get(CONF_UNIT_OF_MEASUREMENT) ) is None: diff --git a/homeassistant/components/mqtt/strings.json b/homeassistant/components/mqtt/strings.json index 8fc97362857..9bc6df1b633 100644 --- a/homeassistant/components/mqtt/strings.json +++ b/homeassistant/components/mqtt/strings.json @@ -644,6 +644,7 @@ "invalid_template": "Invalid template", "invalid_supported_color_modes": "Invalid supported color modes selection", "invalid_uom": "The unit of measurement \"{unit_of_measurement}\" is not supported by the selected device class, please either remove the device class, select a device class which supports \"{unit_of_measurement}\", or pick a supported unit of measurement from the list", + "invalid_uom_for_state_class": "The unit of measurement \"{unit_of_measurement}\" is not supported by the selected state class, please either remove the state class, select a state class which supports \"{unit_of_measurement}\", or pick a supported unit of measurement from the list", "invalid_url": "Invalid URL", "last_reset_not_with_state_class_total": "The last reset value template option should be used with state class 'Total' only", "max_below_min_kelvin": "Max Kelvin value should be greater than min Kelvin value", diff --git a/tests/components/mqtt/test_config_flow.py b/tests/components/mqtt/test_config_flow.py index a43617badb0..e30aa5d50d6 100644 --- a/tests/components/mqtt/test_config_flow.py +++ b/tests/components/mqtt/test_config_flow.py @@ -3038,7 +3038,15 @@ async def test_migrate_of_incompatible_config_entry( { "state_class": "measurement", }, - (), + ( + ( + { + "state_class": "measurement_angle", + "unit_of_measurement": "deg", + }, + {"unit_of_measurement": "invalid_uom_for_state_class"}, + ), + ), { "state_topic": "test-topic", }, diff --git a/tests/components/mqtt/test_sensor.py b/tests/components/mqtt/test_sensor.py index 0bafacfed26..ea1b7e186e2 100644 --- a/tests/components/mqtt/test_sensor.py +++ b/tests/components/mqtt/test_sensor.py @@ -995,6 +995,32 @@ async def test_invalid_state_class( assert "expected SensorStateClass or one of" in caplog.text +@pytest.mark.parametrize( + "hass_config", + [ + { + mqtt.DOMAIN: { + sensor.DOMAIN: { + "name": "test", + "state_topic": "test-topic", + "state_class": "measurement_angle", + "unit_of_measurement": "deg", + } + } + } + ], +) +async def test_invalid_state_class_with_unit_of_measurement( + mqtt_mock_entry: MqttMockHAClientGenerator, caplog: pytest.LogCaptureFixture +) -> None: + """Test state_class option with invalid unit of measurement.""" + assert await mqtt_mock_entry() + assert ( + "The unit of measurement 'deg' is not valid together with state class 'measurement_angle'" + in caplog.text + ) + + @pytest.mark.parametrize( ("hass_config", "error_logged"), [