diff --git a/homeassistant/components/media_player/bluesound.py b/homeassistant/components/media_player/bluesound.py index d308b94e64c..a07e577c969 100644 --- a/homeassistant/components/media_player/bluesound.py +++ b/homeassistant/components/media_player/bluesound.py @@ -423,19 +423,17 @@ class BluesoundPlayer(MediaPlayerDevice): for player in self._hass.data[DATA_BLUESOUND]: yield from player.force_update_sync_status() - @asyncio.coroutine @Throttle(SYNC_STATUS_INTERVAL) - def async_update_sync_status(self, on_updated_cb=None, - raise_timeout=False): + async def async_update_sync_status(self, on_updated_cb=None, + raise_timeout=False): """Update sync status.""" - yield from self.force_update_sync_status( + await self.force_update_sync_status( on_updated_cb, raise_timeout=False) - @asyncio.coroutine @Throttle(UPDATE_CAPTURE_INTERVAL) - def async_update_captures(self): + async def async_update_captures(self): """Update Capture sources.""" - resp = yield from self.send_bluesound_command( + resp = await self.send_bluesound_command( 'RadioBrowse?service=Capture') if not resp: return @@ -459,11 +457,10 @@ class BluesoundPlayer(MediaPlayerDevice): return self._capture_items - @asyncio.coroutine @Throttle(UPDATE_PRESETS_INTERVAL) - def async_update_presets(self): + async def async_update_presets(self): """Update Presets.""" - resp = yield from self.send_bluesound_command('Presets') + resp = await self.send_bluesound_command('Presets') if not resp: return self._preset_items = [] @@ -488,11 +485,10 @@ class BluesoundPlayer(MediaPlayerDevice): return self._preset_items - @asyncio.coroutine @Throttle(UPDATE_SERVICES_INTERVAL) - def async_update_services(self): + async def async_update_services(self): """Update Services.""" - resp = yield from self.send_bluesound_command('Services') + resp = await self.send_bluesound_command('Services') if not resp: return self._services_items = [] diff --git a/homeassistant/components/media_player/volumio.py b/homeassistant/components/media_player/volumio.py index 84b957533fe..0a940c0aa9d 100644 --- a/homeassistant/components/media_player/volumio.py +++ b/homeassistant/components/media_player/volumio.py @@ -253,8 +253,7 @@ class Volumio(MediaPlayerDevice): return self.send_volumio_msg('commands', params={'cmd': 'clearQueue'}) - @asyncio.coroutine @Throttle(PLAYLIST_UPDATE_INTERVAL) - def _async_update_playlists(self, **kwargs): + async def _async_update_playlists(self, **kwargs): """Update available Volumio playlists.""" - self._playlists = yield from self.send_volumio_msg('listplaylists') + self._playlists = await self.send_volumio_msg('listplaylists') diff --git a/homeassistant/components/sensor/fido.py b/homeassistant/components/sensor/fido.py index 4fc79745b99..25a104bf259 100644 --- a/homeassistant/components/sensor/fido.py +++ b/homeassistant/components/sensor/fido.py @@ -157,13 +157,12 @@ class FidoData(object): REQUESTS_TIMEOUT, httpsession) self.data = {} - @asyncio.coroutine @Throttle(MIN_TIME_BETWEEN_UPDATES) - def async_update(self): + async def async_update(self): """Get the latest data from Fido.""" from pyfido.client import PyFidoError try: - yield from self.client.fetch_data() + await self.client.fetch_data() except PyFidoError as exp: _LOGGER.error("Error on receive last Fido data: %s", exp) return False diff --git a/homeassistant/components/sensor/hydroquebec.py b/homeassistant/components/sensor/hydroquebec.py index e10abc14ff1..3678ac9268f 100644 --- a/homeassistant/components/sensor/hydroquebec.py +++ b/homeassistant/components/sensor/hydroquebec.py @@ -182,13 +182,12 @@ class HydroquebecData(object): return self.client.get_contracts() return [] - @asyncio.coroutine @Throttle(MIN_TIME_BETWEEN_UPDATES) - def _fetch_data(self): + async def _fetch_data(self): """Fetch latest data from HydroQuebec.""" from pyhydroquebec.client import PyHydroQuebecError try: - yield from self.client.fetch_data() + await self.client.fetch_data() except PyHydroQuebecError as exp: _LOGGER.error("Error on receive last Hydroquebec data: %s", exp) return False diff --git a/homeassistant/components/sensor/luftdaten.py b/homeassistant/components/sensor/luftdaten.py index 72ee8a7ce93..c5e0b12b0e0 100644 --- a/homeassistant/components/sensor/luftdaten.py +++ b/homeassistant/components/sensor/luftdaten.py @@ -133,13 +133,9 @@ class LuftdatenSensor(Entity): except KeyError: return - @asyncio.coroutine - def async_update(self): + async def async_update(self): """Get the latest data from luftdaten.info and update the state.""" - try: - yield from self.luftdaten.async_update() - except TypeError: - pass + await self.luftdaten.async_update() class LuftdatenData(object): @@ -150,12 +146,11 @@ class LuftdatenData(object): self.data = data @Throttle(MIN_TIME_BETWEEN_UPDATES) - @asyncio.coroutine - def async_update(self): + async def async_update(self): """Get the latest data from luftdaten.info.""" from luftdaten.exceptions import LuftdatenError try: - yield from self.data.async_get_data() + await self.data.async_get_data() except LuftdatenError: _LOGGER.error("Unable to retrieve data from luftdaten.info") diff --git a/homeassistant/components/sensor/sabnzbd.py b/homeassistant/components/sensor/sabnzbd.py index 632e1ed5c1d..c5dd09e0ccc 100644 --- a/homeassistant/components/sensor/sabnzbd.py +++ b/homeassistant/components/sensor/sabnzbd.py @@ -75,15 +75,14 @@ def setup_sabnzbd(base_url, apikey, name, config, for variable in monitored]) -@asyncio.coroutine @Throttle(MIN_TIME_BETWEEN_UPDATES) -def async_update_queue(sab_api): +async def async_update_queue(sab_api): """ Throttled function to update SABnzbd queue. This ensures that the queue info only gets updated once for all sensors """ - yield from sab_api.refresh_queue() + await sab_api.refresh_queue() def request_configuration(host, name, hass, config, async_add_devices, diff --git a/homeassistant/components/sensor/startca.py b/homeassistant/components/sensor/startca.py index a5908812b6c..aefbc2d4626 100644 --- a/homeassistant/components/sensor/startca.py +++ b/homeassistant/components/sensor/startca.py @@ -140,21 +140,20 @@ class StartcaData(object): """ return float(value) * 10 ** -9 - @asyncio.coroutine @Throttle(MIN_TIME_BETWEEN_UPDATES) - def async_update(self): + async def async_update(self): """Get the Start.ca bandwidth data from the web service.""" import xmltodict _LOGGER.debug("Updating Start.ca usage data") url = 'https://www.start.ca/support/usage/api?key=' + \ self.api_key with async_timeout.timeout(REQUEST_TIMEOUT, loop=self.loop): - req = yield from self.websession.get(url) + req = await self.websession.get(url) if req.status != 200: _LOGGER.error("Request failed with status: %u", req.status) return False - data = yield from req.text() + data = await req.text() try: xml_data = xmltodict.parse(data) except ExpatError: diff --git a/homeassistant/components/sensor/teksavvy.py b/homeassistant/components/sensor/teksavvy.py index 9c4263422ff..0bf1ef4caff 100644 --- a/homeassistant/components/sensor/teksavvy.py +++ b/homeassistant/components/sensor/teksavvy.py @@ -132,22 +132,21 @@ class TekSavvyData(object): self.data = {"limit": self.bandwidth_cap} if self.bandwidth_cap > 0 \ else {"limit": float('inf')} - @asyncio.coroutine @Throttle(MIN_TIME_BETWEEN_UPDATES) - def async_update(self): + async def async_update(self): """Get the TekSavvy bandwidth data from the web service.""" headers = {"TekSavvy-APIKey": self.api_key} _LOGGER.debug("Updating TekSavvy data") url = "https://api.teksavvy.com/"\ "web/Usage/UsageSummaryRecords?$filter=IsCurrent%20eq%20true" with async_timeout.timeout(REQUEST_TIMEOUT, loop=self.loop): - req = yield from self.websession.get(url, headers=headers) + req = await self.websession.get(url, headers=headers) if req.status != 200: _LOGGER.error("Request failed with status: %u", req.status) return False try: - data = yield from req.json() + data = await req.json() for (api, ha_name) in API_HA_MAP: self.data[ha_name] = float(data["value"][0][api]) on_peak_download = self.data["onpeak_download"] diff --git a/homeassistant/components/sensor/wunderground.py b/homeassistant/components/sensor/wunderground.py index edcc1c92bf9..0375bb1344c 100644 --- a/homeassistant/components/sensor/wunderground.py +++ b/homeassistant/components/sensor/wunderground.py @@ -777,14 +777,13 @@ class WUndergroundData(object): return url + '.json' - @asyncio.coroutine @Throttle(MIN_TIME_BETWEEN_UPDATES) - def async_update(self): + async def async_update(self): """Get the latest data from WUnderground.""" try: with async_timeout.timeout(10, loop=self._hass.loop): - response = yield from self._session.get(self._build_url()) - result = yield from response.json() + response = await self._session.get(self._build_url()) + result = await response.json() if "error" in result['response']: raise ValueError(result['response']["error"]["description"]) self.data = result diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index 75721a37466..a869251dc3c 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -1,4 +1,5 @@ """Helper methods for various modules.""" +import asyncio from collections.abc import MutableSet from itertools import chain import threading @@ -276,6 +277,16 @@ class Throttle(object): is_func = (not hasattr(method, '__self__') and '.' not in method.__qualname__.split('..')[-1]) + # Make sure we return a coroutine if the method is async. + if asyncio.iscoroutinefunction(method): + async def throttled_value(): + """Stand-in function for when real func is being throttled.""" + return None + else: + def throttled_value(): + """Stand-in function for when real func is being throttled.""" + return None + @wraps(method) def wrapper(*args, **kwargs): """Wrap that allows wrapped to be called only once per min_time. @@ -298,7 +309,7 @@ class Throttle(object): throttle = host._throttle[id(self)] if not throttle[0].acquire(False): - return None + return throttled_value() # Check if method is never called or no_throttle is given force = kwargs.pop('no_throttle', False) or not throttle[1] @@ -309,7 +320,7 @@ class Throttle(object): throttle[1] = utcnow() return result - return None + return throttled_value() finally: throttle[0].release() diff --git a/tests/util/test_init.py b/tests/util/test_init.py index 2902cb62517..5493843c246 100644 --- a/tests/util/test_init.py +++ b/tests/util/test_init.py @@ -280,3 +280,14 @@ class TestUtil(unittest.TestCase): mock_random.SystemRandom.return_value = generator assert util.get_random_string(length=3) == 'ABC' + + +async def test_throttle_async(): + """Test Throttle decorator with async method.""" + @util.Throttle(timedelta(seconds=2)) + async def test_method(): + """Only first call should return a value.""" + return True + + assert (await test_method()) is True + assert (await test_method()) is None