Improve energy typing (#116034)

This commit is contained in:
Marc Mueller 2024-04-23 17:33:36 +02:00 committed by GitHub
parent 14e19c6d9c
commit 8257af1b22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 14 deletions

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from collections.abc import Callable, Mapping
import copy
from dataclasses import dataclass
import logging
@ -167,8 +167,7 @@ class SensorManager:
if adapter.flow_type is None:
self._process_sensor_data(
adapter,
# Opting out of the type complexity because can't get it to work
energy_source, # type: ignore[arg-type]
energy_source,
to_add,
to_remove,
)
@ -177,8 +176,7 @@ class SensorManager:
for flow in energy_source[adapter.flow_type]: # type: ignore[typeddict-item]
self._process_sensor_data(
adapter,
# Opting out of the type complexity because can't get it to work
flow, # type: ignore[arg-type]
flow,
to_add,
to_remove,
)
@ -189,7 +187,7 @@ class SensorManager:
def _process_sensor_data(
self,
adapter: SourceAdapter,
config: dict,
config: Mapping[str, Any],
to_add: list[EnergyCostSensor],
to_remove: dict[tuple[str, str | None, str], EnergyCostSensor],
) -> None:
@ -241,7 +239,7 @@ class EnergyCostSensor(SensorEntity):
def __init__(
self,
adapter: SourceAdapter,
config: dict,
config: Mapping[str, Any],
) -> None:
"""Initialize the sensor."""
super().__init__()
@ -456,7 +454,7 @@ class EnergyCostSensor(SensorEntity):
await super().async_will_remove_from_hass()
@callback
def update_config(self, config: dict) -> None:
def update_config(self, config: Mapping[str, Any]) -> None:
"""Update the config."""
self._config = config

View File

@ -31,7 +31,7 @@ from .data import (
EnergyPreferencesUpdate,
async_get_manager,
)
from .types import EnergyPlatform, GetSolarForecastType
from .types import EnergyPlatform, GetSolarForecastType, SolarForecastType
from .validate import async_validate
EnergyWebSocketCommandHandler = Callable[
@ -203,19 +203,18 @@ async def ws_solar_forecast(
for source in manager.data["energy_sources"]:
if (
source["type"] != "solar"
or source.get("config_entry_solar_forecast") is None
or (solar_forecast := source.get("config_entry_solar_forecast")) is None
):
continue
# typing is not catching the above guard for config_entry_solar_forecast being none
for config_entry in source["config_entry_solar_forecast"]: # type: ignore[union-attr]
config_entries[config_entry] = None
for entry in solar_forecast:
config_entries[entry] = None
if not config_entries:
connection.send_result(msg["id"], {})
return
forecasts = {}
forecasts: dict[str, SolarForecastType] = {}
forecast_platforms = await async_get_energy_platforms(hass)