Enable strict typing for prometheus (#108025)

This commit is contained in:
Marc Mueller 2024-01-15 09:07:12 +01:00 committed by GitHub
parent 45acd56861
commit 5bde007048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 64 deletions

View File

@ -320,6 +320,7 @@ homeassistant.components.plugwise.*
homeassistant.components.poolsense.* homeassistant.components.poolsense.*
homeassistant.components.powerwall.* homeassistant.components.powerwall.*
homeassistant.components.private_ble_device.* homeassistant.components.private_ble_device.*
homeassistant.components.prometheus.*
homeassistant.components.proximity.* homeassistant.components.proximity.*
homeassistant.components.prusalink.* homeassistant.components.prusalink.*
homeassistant.components.pure_energie.* homeassistant.components.pure_energie.*

View File

@ -1,10 +1,15 @@
"""Support for Prometheus metrics export.""" """Support for Prometheus metrics export."""
from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
import logging import logging
import string import string
from typing import Any, TypeVar, cast
from aiohttp import web from aiohttp import web
import prometheus_client import prometheus_client
from prometheus_client.metrics import MetricWrapperBase
import voluptuous as vol import voluptuous as vol
from homeassistant import core as hacore from homeassistant import core as hacore
@ -40,15 +45,20 @@ from homeassistant.const import (
STATE_UNKNOWN, STATE_UNKNOWN,
UnitOfTemperature, UnitOfTemperature,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import entityfilter, state as state_helper from homeassistant.helpers import entityfilter, state as state_helper
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_registry import EVENT_ENTITY_REGISTRY_UPDATED from homeassistant.helpers.entity_registry import (
EVENT_ENTITY_REGISTRY_UPDATED,
EventEntityRegistryUpdatedData,
)
from homeassistant.helpers.entity_values import EntityValues from homeassistant.helpers.entity_values import EntityValues
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.event import EventStateChangedData
from homeassistant.helpers.typing import ConfigType, EventType
from homeassistant.util.dt import as_timestamp from homeassistant.util.dt import as_timestamp
from homeassistant.util.unit_conversion import TemperatureConverter from homeassistant.util.unit_conversion import TemperatureConverter
_MetricBaseT = TypeVar("_MetricBaseT", bound=MetricWrapperBase)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
API_ENDPOINT = "/api/prometheus" API_ENDPOINT = "/api/prometheus"
@ -97,12 +107,12 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Activate Prometheus component.""" """Activate Prometheus component."""
hass.http.register_view(PrometheusView(config[DOMAIN][CONF_REQUIRES_AUTH])) hass.http.register_view(PrometheusView(config[DOMAIN][CONF_REQUIRES_AUTH]))
conf = config[DOMAIN] conf: dict[str, Any] = config[DOMAIN]
entity_filter = conf[CONF_FILTER] entity_filter: entityfilter.EntityFilter = conf[CONF_FILTER]
namespace = conf.get(CONF_PROM_NAMESPACE) namespace: str = conf[CONF_PROM_NAMESPACE]
climate_units = hass.config.units.temperature_unit climate_units = hass.config.units.temperature_unit
override_metric = conf.get(CONF_OVERRIDE_METRIC) override_metric: str | None = conf.get(CONF_OVERRIDE_METRIC)
default_metric = conf.get(CONF_DEFAULT_METRIC) default_metric: str | None = conf.get(CONF_DEFAULT_METRIC)
component_config = EntityValues( component_config = EntityValues(
conf[CONF_COMPONENT_CONFIG], conf[CONF_COMPONENT_CONFIG],
conf[CONF_COMPONENT_CONFIG_DOMAIN], conf[CONF_COMPONENT_CONFIG_DOMAIN],
@ -118,9 +128,10 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
default_metric, default_metric,
) )
hass.bus.listen(EVENT_STATE_CHANGED, metrics.handle_state_changed_event) hass.bus.listen(EVENT_STATE_CHANGED, metrics.handle_state_changed_event) # type: ignore[arg-type]
hass.bus.listen( hass.bus.listen(
EVENT_ENTITY_REGISTRY_UPDATED, metrics.handle_entity_registry_updated EVENT_ENTITY_REGISTRY_UPDATED,
metrics.handle_entity_registry_updated, # type: ignore[arg-type]
) )
for state in hass.states.all(): for state in hass.states.all():
@ -135,19 +146,21 @@ class PrometheusMetrics:
def __init__( def __init__(
self, self,
entity_filter, entity_filter: entityfilter.EntityFilter,
namespace, namespace: str,
climate_units, climate_units: UnitOfTemperature,
component_config, component_config: EntityValues,
override_metric, override_metric: str | None,
default_metric, default_metric: str | None,
): ) -> None:
"""Initialize Prometheus Metrics.""" """Initialize Prometheus Metrics."""
self._component_config = component_config self._component_config = component_config
self._override_metric = override_metric self._override_metric = override_metric
self._default_metric = default_metric self._default_metric = default_metric
self._filter = entity_filter self._filter = entity_filter
self._sensor_metric_handlers = [ self._sensor_metric_handlers: list[
Callable[[State, str | None], str | None]
] = [
self._sensor_override_component_metric, self._sensor_override_component_metric,
self._sensor_override_metric, self._sensor_override_metric,
self._sensor_timestamp_metric, self._sensor_timestamp_metric,
@ -160,10 +173,12 @@ class PrometheusMetrics:
self.metrics_prefix = f"{namespace}_" self.metrics_prefix = f"{namespace}_"
else: else:
self.metrics_prefix = "" self.metrics_prefix = ""
self._metrics = {} self._metrics: dict[str, MetricWrapperBase] = {}
self._climate_units = climate_units self._climate_units = climate_units
def handle_state_changed_event(self, event): def handle_state_changed_event(
self, event: EventType[EventStateChangedData]
) -> None:
"""Handle new messages from the bus.""" """Handle new messages from the bus."""
if (state := event.data.get("new_state")) is None: if (state := event.data.get("new_state")) is None:
return return
@ -179,7 +194,7 @@ class PrometheusMetrics:
self.handle_state(state) self.handle_state(state)
def handle_state(self, state): def handle_state(self, state: State) -> None:
"""Add/update a state in Prometheus.""" """Add/update a state in Prometheus."""
entity_id = state.entity_id entity_id = state.entity_id
_LOGGER.debug("Handling state update for %s", entity_id) _LOGGER.debug("Handling state update for %s", entity_id)
@ -212,20 +227,22 @@ class PrometheusMetrics:
) )
last_updated_time_seconds.labels(**labels).set(state.last_updated.timestamp()) last_updated_time_seconds.labels(**labels).set(state.last_updated.timestamp())
def handle_entity_registry_updated(self, event): def handle_entity_registry_updated(
self, event: EventType[EventEntityRegistryUpdatedData]
) -> None:
"""Listen for deleted, disabled or renamed entities and remove them from the Prometheus Registry.""" """Listen for deleted, disabled or renamed entities and remove them from the Prometheus Registry."""
if (action := event.data.get("action")) in (None, "create"): if event.data["action"] in (None, "create"):
return return
entity_id = event.data.get("entity_id") entity_id = event.data.get("entity_id")
_LOGGER.debug("Handling entity update for %s", entity_id) _LOGGER.debug("Handling entity update for %s", entity_id)
metrics_entity_id = None metrics_entity_id: str | None = None
if action == "remove": if event.data["action"] == "remove":
metrics_entity_id = entity_id metrics_entity_id = entity_id
elif action == "update": elif event.data["action"] == "update":
changes = event.data.get("changes") changes = event.data["changes"]
if "entity_id" in changes: if "entity_id" in changes:
metrics_entity_id = changes["entity_id"] metrics_entity_id = changes["entity_id"]
@ -235,10 +252,14 @@ class PrometheusMetrics:
if metrics_entity_id: if metrics_entity_id:
self._remove_labelsets(metrics_entity_id) self._remove_labelsets(metrics_entity_id)
def _remove_labelsets(self, entity_id, friendly_name=None): def _remove_labelsets(
self, entity_id: str, friendly_name: str | None = None
) -> None:
"""Remove labelsets matching the given entity id from all metrics.""" """Remove labelsets matching the given entity id from all metrics."""
for _, metric in self._metrics.items(): for _, metric in self._metrics.items():
for sample in metric.collect()[0].samples: for sample in cast(list[prometheus_client.Metric], metric.collect())[
0
].samples:
if sample.labels["entity"] == entity_id and ( if sample.labels["entity"] == entity_id and (
not friendly_name or sample.labels["friendly_name"] == friendly_name not friendly_name or sample.labels["friendly_name"] == friendly_name
): ):
@ -250,7 +271,7 @@ class PrometheusMetrics:
with suppress(KeyError): with suppress(KeyError):
metric.remove(*sample.labels.values()) metric.remove(*sample.labels.values())
def _handle_attributes(self, state): def _handle_attributes(self, state: State) -> None:
for key, value in state.attributes.items(): for key, value in state.attributes.items():
metric = self._metric( metric = self._metric(
f"{state.domain}_attr_{key.lower()}", f"{state.domain}_attr_{key.lower()}",
@ -264,13 +285,19 @@ class PrometheusMetrics:
except (ValueError, TypeError): except (ValueError, TypeError):
pass pass
def _metric(self, metric, factory, documentation, extra_labels=None): def _metric(
self,
metric: str,
factory: type[_MetricBaseT],
documentation: str,
extra_labels: list[str] | None = None,
) -> _MetricBaseT:
labels = ["entity", "friendly_name", "domain"] labels = ["entity", "friendly_name", "domain"]
if extra_labels is not None: if extra_labels is not None:
labels.extend(extra_labels) labels.extend(extra_labels)
try: try:
return self._metrics[metric] return cast(_MetricBaseT, self._metrics[metric])
except KeyError: except KeyError:
full_metric_name = self._sanitize_metric_name( full_metric_name = self._sanitize_metric_name(
f"{self.metrics_prefix}{metric}" f"{self.metrics_prefix}{metric}"
@ -281,7 +308,7 @@ class PrometheusMetrics:
labels, labels,
registry=prometheus_client.REGISTRY, registry=prometheus_client.REGISTRY,
) )
return self._metrics[metric] return cast(_MetricBaseT, self._metrics[metric])
@staticmethod @staticmethod
def _sanitize_metric_name(metric: str) -> str: def _sanitize_metric_name(metric: str) -> str:
@ -298,7 +325,7 @@ class PrometheusMetrics:
) )
@staticmethod @staticmethod
def state_as_number(state): def state_as_number(state: State) -> float:
"""Return a state casted to a float.""" """Return a state casted to a float."""
try: try:
if state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TIMESTAMP: if state.attributes.get(ATTR_DEVICE_CLASS) == SensorDeviceClass.TIMESTAMP:
@ -311,14 +338,14 @@ class PrometheusMetrics:
return value return value
@staticmethod @staticmethod
def _labels(state): def _labels(state: State) -> dict[str, Any]:
return { return {
"entity": state.entity_id, "entity": state.entity_id,
"domain": state.domain, "domain": state.domain,
"friendly_name": state.attributes.get(ATTR_FRIENDLY_NAME), "friendly_name": state.attributes.get(ATTR_FRIENDLY_NAME),
} }
def _battery(self, state): def _battery(self, state: State) -> None:
if (battery_level := state.attributes.get(ATTR_BATTERY_LEVEL)) is not None: if (battery_level := state.attributes.get(ATTR_BATTERY_LEVEL)) is not None:
metric = self._metric( metric = self._metric(
"battery_level_percent", "battery_level_percent",
@ -331,7 +358,7 @@ class PrometheusMetrics:
except ValueError: except ValueError:
pass pass
def _handle_binary_sensor(self, state): def _handle_binary_sensor(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"binary_sensor_state", "binary_sensor_state",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -340,7 +367,7 @@ class PrometheusMetrics:
value = self.state_as_number(state) value = self.state_as_number(state)
metric.labels(**self._labels(state)).set(value) metric.labels(**self._labels(state)).set(value)
def _handle_input_boolean(self, state): def _handle_input_boolean(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"input_boolean_state", "input_boolean_state",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -349,7 +376,7 @@ class PrometheusMetrics:
value = self.state_as_number(state) value = self.state_as_number(state)
metric.labels(**self._labels(state)).set(value) metric.labels(**self._labels(state)).set(value)
def _numeric_handler(self, state, domain, title): def _numeric_handler(self, state: State, domain: str, title: str) -> None:
if unit := self._unit_string(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)): if unit := self._unit_string(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)):
metric = self._metric( metric = self._metric(
f"{domain}_state_{unit}", f"{domain}_state_{unit}",
@ -374,13 +401,13 @@ class PrometheusMetrics:
) )
metric.labels(**self._labels(state)).set(value) metric.labels(**self._labels(state)).set(value)
def _handle_input_number(self, state): def _handle_input_number(self, state: State) -> None:
self._numeric_handler(state, "input_number", "input number") self._numeric_handler(state, "input_number", "input number")
def _handle_number(self, state): def _handle_number(self, state: State) -> None:
self._numeric_handler(state, "number", "number") self._numeric_handler(state, "number", "number")
def _handle_device_tracker(self, state): def _handle_device_tracker(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"device_tracker_state", "device_tracker_state",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -389,14 +416,14 @@ class PrometheusMetrics:
value = self.state_as_number(state) value = self.state_as_number(state)
metric.labels(**self._labels(state)).set(value) metric.labels(**self._labels(state)).set(value)
def _handle_person(self, state): def _handle_person(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"person_state", prometheus_client.Gauge, "State of the person (0/1)" "person_state", prometheus_client.Gauge, "State of the person (0/1)"
) )
value = self.state_as_number(state) value = self.state_as_number(state)
metric.labels(**self._labels(state)).set(value) metric.labels(**self._labels(state)).set(value)
def _handle_cover(self, state): def _handle_cover(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"cover_state", "cover_state",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -428,7 +455,7 @@ class PrometheusMetrics:
) )
tilt_position_metric.labels(**self._labels(state)).set(float(tilt_position)) tilt_position_metric.labels(**self._labels(state)).set(float(tilt_position))
def _handle_light(self, state): def _handle_light(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"light_brightness_percent", "light_brightness_percent",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -446,14 +473,16 @@ class PrometheusMetrics:
except ValueError: except ValueError:
pass pass
def _handle_lock(self, state): def _handle_lock(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"lock_state", prometheus_client.Gauge, "State of the lock (0/1)" "lock_state", prometheus_client.Gauge, "State of the lock (0/1)"
) )
value = self.state_as_number(state) value = self.state_as_number(state)
metric.labels(**self._labels(state)).set(value) metric.labels(**self._labels(state)).set(value)
def _handle_climate_temp(self, state, attr, metric_name, metric_description): def _handle_climate_temp(
self, state: State, attr: str, metric_name: str, metric_description: str
) -> None:
if (temp := state.attributes.get(attr)) is not None: if (temp := state.attributes.get(attr)) is not None:
if self._climate_units == UnitOfTemperature.FAHRENHEIT: if self._climate_units == UnitOfTemperature.FAHRENHEIT:
temp = TemperatureConverter.convert( temp = TemperatureConverter.convert(
@ -466,7 +495,7 @@ class PrometheusMetrics:
) )
metric.labels(**self._labels(state)).set(temp) metric.labels(**self._labels(state)).set(temp)
def _handle_climate(self, state): def _handle_climate(self, state: State) -> None:
self._handle_climate_temp( self._handle_climate_temp(
state, state,
ATTR_TEMPERATURE, ATTR_TEMPERATURE,
@ -518,7 +547,7 @@ class PrometheusMetrics:
float(mode == current_mode) float(mode == current_mode)
) )
def _handle_humidifier(self, state): def _handle_humidifier(self, state: State) -> None:
humidifier_target_humidity_percent = state.attributes.get(ATTR_HUMIDITY) humidifier_target_humidity_percent = state.attributes.get(ATTR_HUMIDITY)
if humidifier_target_humidity_percent: if humidifier_target_humidity_percent:
metric = self._metric( metric = self._metric(
@ -553,7 +582,7 @@ class PrometheusMetrics:
float(mode == current_mode) float(mode == current_mode)
) )
def _handle_sensor(self, state): def _handle_sensor(self, state: State) -> None:
unit = self._unit_string(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)) unit = self._unit_string(state.attributes.get(ATTR_UNIT_OF_MEASUREMENT))
for metric_handler in self._sensor_metric_handlers: for metric_handler in self._sensor_metric_handlers:
@ -583,12 +612,12 @@ class PrometheusMetrics:
self._battery(state) self._battery(state)
def _sensor_default_metric(self, state, unit): def _sensor_default_metric(self, state: State, unit: str | None) -> str | None:
"""Get default metric.""" """Get default metric."""
return self._default_metric return self._default_metric
@staticmethod @staticmethod
def _sensor_attribute_metric(state, unit): def _sensor_attribute_metric(state: State, unit: str | None) -> str | None:
"""Get metric based on device class attribute.""" """Get metric based on device class attribute."""
metric = state.attributes.get(ATTR_DEVICE_CLASS) metric = state.attributes.get(ATTR_DEVICE_CLASS)
if metric is not None: if metric is not None:
@ -596,25 +625,27 @@ class PrometheusMetrics:
return None return None
@staticmethod @staticmethod
def _sensor_timestamp_metric(state, unit): def _sensor_timestamp_metric(state: State, unit: str | None) -> str | None:
"""Get metric for timestamp sensors, which have no unit of measurement attribute.""" """Get metric for timestamp sensors, which have no unit of measurement attribute."""
metric = state.attributes.get(ATTR_DEVICE_CLASS) metric = state.attributes.get(ATTR_DEVICE_CLASS)
if metric == SensorDeviceClass.TIMESTAMP: if metric == SensorDeviceClass.TIMESTAMP:
return f"sensor_{metric}_seconds" return f"sensor_{metric}_seconds"
return None return None
def _sensor_override_metric(self, state, unit): def _sensor_override_metric(self, state: State, unit: str | None) -> str | None:
"""Get metric from override in configuration.""" """Get metric from override in configuration."""
if self._override_metric: if self._override_metric:
return self._override_metric return self._override_metric
return None return None
def _sensor_override_component_metric(self, state, unit): def _sensor_override_component_metric(
self, state: State, unit: str | None
) -> str | None:
"""Get metric from override in component confioguration.""" """Get metric from override in component confioguration."""
return self._component_config.get(state.entity_id).get(CONF_OVERRIDE_METRIC) return self._component_config.get(state.entity_id).get(CONF_OVERRIDE_METRIC)
@staticmethod @staticmethod
def _sensor_fallback_metric(state, unit): def _sensor_fallback_metric(state: State, unit: str | None) -> str | None:
"""Get metric from fallback logic for compatibility.""" """Get metric from fallback logic for compatibility."""
if unit in (None, ""): if unit in (None, ""):
try: try:
@ -626,10 +657,10 @@ class PrometheusMetrics:
return f"sensor_unit_{unit}" return f"sensor_unit_{unit}"
@staticmethod @staticmethod
def _unit_string(unit): def _unit_string(unit: str | None) -> str | None:
"""Get a formatted string of the unit.""" """Get a formatted string of the unit."""
if unit is None: if unit is None:
return return None
units = { units = {
UnitOfTemperature.CELSIUS: "celsius", UnitOfTemperature.CELSIUS: "celsius",
@ -640,7 +671,7 @@ class PrometheusMetrics:
default = default.lower() default = default.lower()
return units.get(unit, default) return units.get(unit, default)
def _handle_switch(self, state): def _handle_switch(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"switch_state", prometheus_client.Gauge, "State of the switch (0/1)" "switch_state", prometheus_client.Gauge, "State of the switch (0/1)"
) )
@ -653,10 +684,10 @@ class PrometheusMetrics:
self._handle_attributes(state) self._handle_attributes(state)
def _handle_zwave(self, state): def _handle_zwave(self, state: State) -> None:
self._battery(state) self._battery(state)
def _handle_automation(self, state): def _handle_automation(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"automation_triggered_count", "automation_triggered_count",
prometheus_client.Counter, prometheus_client.Counter,
@ -665,7 +696,7 @@ class PrometheusMetrics:
metric.labels(**self._labels(state)).inc() metric.labels(**self._labels(state)).inc()
def _handle_counter(self, state): def _handle_counter(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"counter_value", "counter_value",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -674,7 +705,7 @@ class PrometheusMetrics:
metric.labels(**self._labels(state)).set(self.state_as_number(state)) metric.labels(**self._labels(state)).set(self.state_as_number(state))
def _handle_update(self, state): def _handle_update(self, state: State) -> None:
metric = self._metric( metric = self._metric(
"update_state", "update_state",
prometheus_client.Gauge, prometheus_client.Gauge,
@ -694,7 +725,7 @@ class PrometheusView(HomeAssistantView):
"""Initialize Prometheus view.""" """Initialize Prometheus view."""
self.requires_auth = requires_auth self.requires_auth = requires_auth
async def get(self, request): async def get(self, request: web.Request) -> web.Response:
"""Handle request for Prometheus metrics.""" """Handle request for Prometheus metrics."""
_LOGGER.debug("Received Prometheus metrics request") _LOGGER.debug("Received Prometheus metrics request")

View File

@ -2961,6 +2961,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.prometheus.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.proximity.*] [mypy-homeassistant.components.proximity.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true