Make Throttle async aware (#13027)

* Make Throttle async aware

* Lint
This commit is contained in:
Paulus Schoutsen 2018-03-09 19:38:51 -08:00
parent 8f807a3006
commit 6ffc53b290
11 changed files with 54 additions and 48 deletions

View File

@ -423,19 +423,17 @@ class BluesoundPlayer(MediaPlayerDevice):
for player in self._hass.data[DATA_BLUESOUND]: for player in self._hass.data[DATA_BLUESOUND]:
yield from player.force_update_sync_status() yield from player.force_update_sync_status()
@asyncio.coroutine
@Throttle(SYNC_STATUS_INTERVAL) @Throttle(SYNC_STATUS_INTERVAL)
def async_update_sync_status(self, on_updated_cb=None, async def async_update_sync_status(self, on_updated_cb=None,
raise_timeout=False): raise_timeout=False):
"""Update sync status.""" """Update sync status."""
yield from self.force_update_sync_status( await self.force_update_sync_status(
on_updated_cb, raise_timeout=False) on_updated_cb, raise_timeout=False)
@asyncio.coroutine
@Throttle(UPDATE_CAPTURE_INTERVAL) @Throttle(UPDATE_CAPTURE_INTERVAL)
def async_update_captures(self): async def async_update_captures(self):
"""Update Capture sources.""" """Update Capture sources."""
resp = yield from self.send_bluesound_command( resp = await self.send_bluesound_command(
'RadioBrowse?service=Capture') 'RadioBrowse?service=Capture')
if not resp: if not resp:
return return
@ -459,11 +457,10 @@ class BluesoundPlayer(MediaPlayerDevice):
return self._capture_items return self._capture_items
@asyncio.coroutine
@Throttle(UPDATE_PRESETS_INTERVAL) @Throttle(UPDATE_PRESETS_INTERVAL)
def async_update_presets(self): async def async_update_presets(self):
"""Update Presets.""" """Update Presets."""
resp = yield from self.send_bluesound_command('Presets') resp = await self.send_bluesound_command('Presets')
if not resp: if not resp:
return return
self._preset_items = [] self._preset_items = []
@ -488,11 +485,10 @@ class BluesoundPlayer(MediaPlayerDevice):
return self._preset_items return self._preset_items
@asyncio.coroutine
@Throttle(UPDATE_SERVICES_INTERVAL) @Throttle(UPDATE_SERVICES_INTERVAL)
def async_update_services(self): async def async_update_services(self):
"""Update Services.""" """Update Services."""
resp = yield from self.send_bluesound_command('Services') resp = await self.send_bluesound_command('Services')
if not resp: if not resp:
return return
self._services_items = [] self._services_items = []

View File

@ -253,8 +253,7 @@ class Volumio(MediaPlayerDevice):
return self.send_volumio_msg('commands', return self.send_volumio_msg('commands',
params={'cmd': 'clearQueue'}) params={'cmd': 'clearQueue'})
@asyncio.coroutine
@Throttle(PLAYLIST_UPDATE_INTERVAL) @Throttle(PLAYLIST_UPDATE_INTERVAL)
def _async_update_playlists(self, **kwargs): async def _async_update_playlists(self, **kwargs):
"""Update available Volumio playlists.""" """Update available Volumio playlists."""
self._playlists = yield from self.send_volumio_msg('listplaylists') self._playlists = await self.send_volumio_msg('listplaylists')

View File

@ -157,13 +157,12 @@ class FidoData(object):
REQUESTS_TIMEOUT, httpsession) REQUESTS_TIMEOUT, httpsession)
self.data = {} self.data = {}
@asyncio.coroutine
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
def async_update(self): async def async_update(self):
"""Get the latest data from Fido.""" """Get the latest data from Fido."""
from pyfido.client import PyFidoError from pyfido.client import PyFidoError
try: try:
yield from self.client.fetch_data() await self.client.fetch_data()
except PyFidoError as exp: except PyFidoError as exp:
_LOGGER.error("Error on receive last Fido data: %s", exp) _LOGGER.error("Error on receive last Fido data: %s", exp)
return False return False

View File

@ -182,13 +182,12 @@ class HydroquebecData(object):
return self.client.get_contracts() return self.client.get_contracts()
return [] return []
@asyncio.coroutine
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
def _fetch_data(self): async def _fetch_data(self):
"""Fetch latest data from HydroQuebec.""" """Fetch latest data from HydroQuebec."""
from pyhydroquebec.client import PyHydroQuebecError from pyhydroquebec.client import PyHydroQuebecError
try: try:
yield from self.client.fetch_data() await self.client.fetch_data()
except PyHydroQuebecError as exp: except PyHydroQuebecError as exp:
_LOGGER.error("Error on receive last Hydroquebec data: %s", exp) _LOGGER.error("Error on receive last Hydroquebec data: %s", exp)
return False return False

View File

@ -133,13 +133,9 @@ class LuftdatenSensor(Entity):
except KeyError: except KeyError:
return return
@asyncio.coroutine async def async_update(self):
def async_update(self):
"""Get the latest data from luftdaten.info and update the state.""" """Get the latest data from luftdaten.info and update the state."""
try: await self.luftdaten.async_update()
yield from self.luftdaten.async_update()
except TypeError:
pass
class LuftdatenData(object): class LuftdatenData(object):
@ -150,12 +146,11 @@ class LuftdatenData(object):
self.data = data self.data = data
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
@asyncio.coroutine async def async_update(self):
def async_update(self):
"""Get the latest data from luftdaten.info.""" """Get the latest data from luftdaten.info."""
from luftdaten.exceptions import LuftdatenError from luftdaten.exceptions import LuftdatenError
try: try:
yield from self.data.async_get_data() await self.data.async_get_data()
except LuftdatenError: except LuftdatenError:
_LOGGER.error("Unable to retrieve data from luftdaten.info") _LOGGER.error("Unable to retrieve data from luftdaten.info")

View File

@ -75,15 +75,14 @@ def setup_sabnzbd(base_url, apikey, name, config,
for variable in monitored]) for variable in monitored])
@asyncio.coroutine
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
def async_update_queue(sab_api): async def async_update_queue(sab_api):
""" """
Throttled function to update SABnzbd queue. Throttled function to update SABnzbd queue.
This ensures that the queue info only gets updated once for all sensors This ensures that the queue info only gets updated once for all sensors
""" """
yield from sab_api.refresh_queue() await sab_api.refresh_queue()
def request_configuration(host, name, hass, config, async_add_devices, def request_configuration(host, name, hass, config, async_add_devices,

View File

@ -140,21 +140,20 @@ class StartcaData(object):
""" """
return float(value) * 10 ** -9 return float(value) * 10 ** -9
@asyncio.coroutine
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
def async_update(self): async def async_update(self):
"""Get the Start.ca bandwidth data from the web service.""" """Get the Start.ca bandwidth data from the web service."""
import xmltodict import xmltodict
_LOGGER.debug("Updating Start.ca usage data") _LOGGER.debug("Updating Start.ca usage data")
url = 'https://www.start.ca/support/usage/api?key=' + \ url = 'https://www.start.ca/support/usage/api?key=' + \
self.api_key self.api_key
with async_timeout.timeout(REQUEST_TIMEOUT, loop=self.loop): with async_timeout.timeout(REQUEST_TIMEOUT, loop=self.loop):
req = yield from self.websession.get(url) req = await self.websession.get(url)
if req.status != 200: if req.status != 200:
_LOGGER.error("Request failed with status: %u", req.status) _LOGGER.error("Request failed with status: %u", req.status)
return False return False
data = yield from req.text() data = await req.text()
try: try:
xml_data = xmltodict.parse(data) xml_data = xmltodict.parse(data)
except ExpatError: except ExpatError:

View File

@ -132,22 +132,21 @@ class TekSavvyData(object):
self.data = {"limit": self.bandwidth_cap} if self.bandwidth_cap > 0 \ self.data = {"limit": self.bandwidth_cap} if self.bandwidth_cap > 0 \
else {"limit": float('inf')} else {"limit": float('inf')}
@asyncio.coroutine
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
def async_update(self): async def async_update(self):
"""Get the TekSavvy bandwidth data from the web service.""" """Get the TekSavvy bandwidth data from the web service."""
headers = {"TekSavvy-APIKey": self.api_key} headers = {"TekSavvy-APIKey": self.api_key}
_LOGGER.debug("Updating TekSavvy data") _LOGGER.debug("Updating TekSavvy data")
url = "https://api.teksavvy.com/"\ url = "https://api.teksavvy.com/"\
"web/Usage/UsageSummaryRecords?$filter=IsCurrent%20eq%20true" "web/Usage/UsageSummaryRecords?$filter=IsCurrent%20eq%20true"
with async_timeout.timeout(REQUEST_TIMEOUT, loop=self.loop): with async_timeout.timeout(REQUEST_TIMEOUT, loop=self.loop):
req = yield from self.websession.get(url, headers=headers) req = await self.websession.get(url, headers=headers)
if req.status != 200: if req.status != 200:
_LOGGER.error("Request failed with status: %u", req.status) _LOGGER.error("Request failed with status: %u", req.status)
return False return False
try: try:
data = yield from req.json() data = await req.json()
for (api, ha_name) in API_HA_MAP: for (api, ha_name) in API_HA_MAP:
self.data[ha_name] = float(data["value"][0][api]) self.data[ha_name] = float(data["value"][0][api])
on_peak_download = self.data["onpeak_download"] on_peak_download = self.data["onpeak_download"]

View File

@ -777,14 +777,13 @@ class WUndergroundData(object):
return url + '.json' return url + '.json'
@asyncio.coroutine
@Throttle(MIN_TIME_BETWEEN_UPDATES) @Throttle(MIN_TIME_BETWEEN_UPDATES)
def async_update(self): async def async_update(self):
"""Get the latest data from WUnderground.""" """Get the latest data from WUnderground."""
try: try:
with async_timeout.timeout(10, loop=self._hass.loop): with async_timeout.timeout(10, loop=self._hass.loop):
response = yield from self._session.get(self._build_url()) response = await self._session.get(self._build_url())
result = yield from response.json() result = await response.json()
if "error" in result['response']: if "error" in result['response']:
raise ValueError(result['response']["error"]["description"]) raise ValueError(result['response']["error"]["description"])
self.data = result self.data = result

View File

@ -1,4 +1,5 @@
"""Helper methods for various modules.""" """Helper methods for various modules."""
import asyncio
from collections.abc import MutableSet from collections.abc import MutableSet
from itertools import chain from itertools import chain
import threading import threading
@ -276,6 +277,16 @@ class Throttle(object):
is_func = (not hasattr(method, '__self__') and is_func = (not hasattr(method, '__self__') and
'.' not in method.__qualname__.split('.<locals>.')[-1]) '.' not in method.__qualname__.split('.<locals>.')[-1])
# Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method):
async def throttled_value():
"""Stand-in function for when real func is being throttled."""
return None
else:
def throttled_value():
"""Stand-in function for when real func is being throttled."""
return None
@wraps(method) @wraps(method)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
"""Wrap that allows wrapped to be called only once per min_time. """Wrap that allows wrapped to be called only once per min_time.
@ -298,7 +309,7 @@ class Throttle(object):
throttle = host._throttle[id(self)] throttle = host._throttle[id(self)]
if not throttle[0].acquire(False): if not throttle[0].acquire(False):
return None return throttled_value()
# Check if method is never called or no_throttle is given # Check if method is never called or no_throttle is given
force = kwargs.pop('no_throttle', False) or not throttle[1] force = kwargs.pop('no_throttle', False) or not throttle[1]
@ -309,7 +320,7 @@ class Throttle(object):
throttle[1] = utcnow() throttle[1] = utcnow()
return result return result
return None return throttled_value()
finally: finally:
throttle[0].release() throttle[0].release()

View File

@ -280,3 +280,14 @@ class TestUtil(unittest.TestCase):
mock_random.SystemRandom.return_value = generator mock_random.SystemRandom.return_value = generator
assert util.get_random_string(length=3) == 'ABC' assert util.get_random_string(length=3) == 'ABC'
async def test_throttle_async():
"""Test Throttle decorator with async method."""
@util.Throttle(timedelta(seconds=2))
async def test_method():
"""Only first call should return a value."""
return True
assert (await test_method()) is True
assert (await test_method()) is None