Use dataclasses in netatmo data handler (#52537)

* Switch to using dataclasses

* Clean up

* Update homeassistant/components/netatmo/data_handler.py
This commit is contained in:
Tobias Sauerwein 2021-07-05 13:05:18 +02:00 committed by GitHub
parent 4cac85e3f5
commit f9c7137d02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 22 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from itertools import islice from itertools import islice
import logging import logging
@ -34,8 +35,6 @@ HOMEDATA_DATA_CLASS_NAME = "AsyncHomeData"
HOMESTATUS_DATA_CLASS_NAME = "AsyncHomeStatus" HOMESTATUS_DATA_CLASS_NAME = "AsyncHomeStatus"
PUBLICDATA_DATA_CLASS_NAME = "AsyncPublicData" PUBLICDATA_DATA_CLASS_NAME = "AsyncPublicData"
NEXT_SCAN = "next_scan"
DATA_CLASSES = { DATA_CLASSES = {
WEATHERSTATION_DATA_CLASS_NAME: pyatmo.AsyncWeatherStationData, WEATHERSTATION_DATA_CLASS_NAME: pyatmo.AsyncWeatherStationData,
HOMECOACH_DATA_CLASS_NAME: pyatmo.AsyncHomeCoachData, HOMECOACH_DATA_CLASS_NAME: pyatmo.AsyncHomeCoachData,
@ -57,6 +56,16 @@ DEFAULT_INTERVALS = {
SCAN_INTERVAL = 60 SCAN_INTERVAL = 60
@dataclass
class NetatmoDataClass:
"""Class for keeping track of Netatmo data class metadata."""
name: str
interval: int
next_scan: float
subscriptions: list[CALLBACK_TYPE]
class NetatmoDataHandler: class NetatmoDataHandler:
"""Manages the Netatmo data handling.""" """Manages the Netatmo data handling."""
@ -93,12 +102,12 @@ class NetatmoDataHandler:
to minimize the calls on the api service. to minimize the calls on the api service.
""" """
for data_class in islice(self._queue, 0, BATCH_SIZE): for data_class in islice(self._queue, 0, BATCH_SIZE):
if data_class[NEXT_SCAN] > time(): if data_class.next_scan > time():
continue continue
if data_class_name := data_class["name"]: if data_class_name := data_class.name:
self.data_classes[data_class_name][NEXT_SCAN] = ( self.data_classes[data_class_name].next_scan = (
time() + data_class["interval"] time() + data_class.interval
) )
await self.async_fetch_data(data_class_name) await self.async_fetch_data(data_class_name)
@ -108,7 +117,7 @@ class NetatmoDataHandler:
@callback @callback
def async_force_update(self, data_class_entry): def async_force_update(self, data_class_entry):
"""Prioritize data retrieval for given data class entry.""" """Prioritize data retrieval for given data class entry."""
self.data_classes[data_class_entry][NEXT_SCAN] = time() self.data_classes[data_class_entry].next_scan = time()
self._queue.rotate(-(self._queue.index(self.data_classes[data_class_entry]))) self._queue.rotate(-(self._queue.index(self.data_classes[data_class_entry])))
async def async_cleanup(self): async def async_cleanup(self):
@ -149,7 +158,7 @@ class NetatmoDataHandler:
_LOGGER.debug(err) _LOGGER.debug(err)
return return
for update_callback in self.data_classes[data_class_entry]["subscriptions"]: for update_callback in self.data_classes[data_class_entry].subscriptions:
if update_callback: if update_callback:
update_callback() update_callback()
@ -158,21 +167,18 @@ class NetatmoDataHandler:
): ):
"""Register data class.""" """Register data class."""
if data_class_entry in self.data_classes: if data_class_entry in self.data_classes:
if ( if update_callback not in self.data_classes[data_class_entry].subscriptions:
update_callback self.data_classes[data_class_entry].subscriptions.append(
not in self.data_classes[data_class_entry]["subscriptions"]
):
self.data_classes[data_class_entry]["subscriptions"].append(
update_callback update_callback
) )
return return
self.data_classes[data_class_entry] = { self.data_classes[data_class_entry] = NetatmoDataClass(
"name": data_class_entry, name=data_class_entry,
"interval": DEFAULT_INTERVALS[data_class_name], interval=DEFAULT_INTERVALS[data_class_name],
NEXT_SCAN: time() + DEFAULT_INTERVALS[data_class_name], next_scan=time() + DEFAULT_INTERVALS[data_class_name],
"subscriptions": [update_callback], subscriptions=[update_callback],
} )
self.data[data_class_entry] = DATA_CLASSES[data_class_name]( self.data[data_class_entry] = DATA_CLASSES[data_class_name](
self._auth, **kwargs self._auth, **kwargs
@ -185,9 +191,9 @@ class NetatmoDataHandler:
async def unregister_data_class(self, data_class_entry, update_callback): async def unregister_data_class(self, data_class_entry, update_callback):
"""Unregister data class.""" """Unregister data class."""
self.data_classes[data_class_entry]["subscriptions"].remove(update_callback) self.data_classes[data_class_entry].subscriptions.remove(update_callback)
if not self.data_classes[data_class_entry].get("subscriptions"): if not self.data_classes[data_class_entry].subscriptions:
self._queue.remove(self.data_classes[data_class_entry]) self._queue.remove(self.data_classes[data_class_entry])
self.data_classes.pop(data_class_entry) self.data_classes.pop(data_class_entry)
self.data.pop(data_class_entry) self.data.pop(data_class_entry)

View File

@ -61,7 +61,7 @@ class NetatmoBase(Entity):
data_class["name"], signal_name, self.async_update_callback data_class["name"], signal_name, self.async_update_callback
) )
for sub in self.data_handler.data_classes[signal_name].get("subscriptions"): for sub in self.data_handler.data_classes[signal_name].subscriptions:
if sub is None: if sub is None:
await self.data_handler.unregister_data_class(signal_name, None) await self.data_handler.unregister_data_class(signal_name, None)