diff --git a/ha_test/test_util.py b/ha_test/test_util.py index ee48ef3a784..c7f0b848ab2 100644 --- a/ha_test/test_util.py +++ b/ha_test/test_util.py @@ -6,7 +6,8 @@ Tests Home Assistant util methods. """ # pylint: disable=too-many-public-methods import unittest -from datetime import datetime +import time +from datetime import datetime, timedelta import homeassistant.util as util @@ -190,3 +191,33 @@ class TestUtil(unittest.TestCase): set1.update([1, 2], [5, 6]) self.assertEqual([2, 3, 1, 5, 6], set1) + + def test_add_cooldown(self): + """ Test the add cooldown decorator. """ + calls = [] + + @util.AddCooldown(timedelta(milliseconds=500)) + def test_cooldown(): + calls.append(1) + + self.assertEqual(0, len(calls)) + + test_cooldown() + + self.assertEqual(1, len(calls)) + + test_cooldown() + + self.assertEqual(1, len(calls)) + + time.sleep(.3) + + test_cooldown() + + self.assertEqual(1, len(calls)) + + time.sleep(.2) + + test_cooldown() + + self.assertEqual(2, len(calls)) diff --git a/homeassistant/components/device_tracker/luci.py b/homeassistant/components/device_tracker/luci.py index 89d50f0239f..5409babacb9 100644 --- a/homeassistant/components/device_tracker/luci.py +++ b/homeassistant/components/device_tracker/luci.py @@ -1,7 +1,7 @@ """ Supports scanning a OpenWRT router. """ import logging import json -from datetime import datetime, timedelta +from datetime import timedelta import re import threading import requests @@ -52,7 +52,6 @@ class LuciDeviceScanner(object): self.lock = threading.Lock() - self.date_updated = None self.last_results = {} self.token = _get_token(host, username, password) @@ -88,29 +87,25 @@ class LuciDeviceScanner(object): return return self.mac2name.get(device, None) + @util.AddCooldown(MIN_TIME_BETWEEN_SCANS) def _update_info(self): """ Ensures the information from the Luci router is up to date. Returns boolean if scanning successful. """ if not self.success_init: return False + with self.lock: - # if date_updated is None or the date is too old we scan - # for new data - if not self.date_updated or \ - datetime.now() - self.date_updated > MIN_TIME_BETWEEN_SCANS: + _LOGGER.info("Checking ARP") - _LOGGER.info("Checking ARP") + url = 'http://{}/cgi-bin/luci/rpc/sys'.format(self.host) + result = _req_json_rpc(url, 'net.arptable', + params={'auth': self.token}) + if result: + self.last_results = [x['HW address'] for x in result] - url = 'http://{}/cgi-bin/luci/rpc/sys'.format(self.host) - result = _req_json_rpc(url, 'net.arptable', - params={'auth': self.token}) - if result: - self.last_results = [x['HW address'] for x in result] - self.date_updated = datetime.now() - return True - return False + return True - return True + return False def _req_json_rpc(url, method, *args, **kwargs): diff --git a/homeassistant/components/device_tracker/netgear.py b/homeassistant/components/device_tracker/netgear.py index 23eda17fff8..98485afefd7 100644 --- a/homeassistant/components/device_tracker/netgear.py +++ b/homeassistant/components/device_tracker/netgear.py @@ -1,6 +1,6 @@ """ Supports scanning a Netgear router. """ import logging -from datetime import datetime, timedelta +from datetime import timedelta import threading import homeassistant as ha @@ -34,7 +34,6 @@ class NetgearDeviceScanner(object): host = config[ha.CONF_HOST] username, password = config[ha.CONF_USERNAME], config[ha.CONF_PASSWORD] - self.date_updated = None self.last_results = [] try: @@ -75,10 +74,6 @@ class NetgearDeviceScanner(object): def get_device_name(self, mac): """ Returns the name of the given device or None if we don't know. """ - # Make sure there are results - if not self.date_updated: - self._update_info() - filter_named = [device.name for device in self.last_results if device.mac == mac] @@ -87,6 +82,7 @@ class NetgearDeviceScanner(object): else: return None + @util.AddCooldown(MIN_TIME_BETWEEN_SCANS) def _update_info(self): """ Retrieves latest information from the Netgear router. Returns boolean if scanning successful. """ @@ -94,18 +90,6 @@ class NetgearDeviceScanner(object): return with self.lock: - # if date_updated is None or the date is too old we scan for - # new data - if not self.date_updated or \ - datetime.now() - self.date_updated > MIN_TIME_BETWEEN_SCANS: + _LOGGER.info("Scanning") - _LOGGER.info("Scanning") - - self.last_results = self._api.get_attached_devices() - - self.date_updated = datetime.now() - - return - - else: - return + self.last_results = self._api.get_attached_devices() diff --git a/homeassistant/components/device_tracker/tomato.py b/homeassistant/components/device_tracker/tomato.py index 748ad53f534..0a5eb1c4fa5 100644 --- a/homeassistant/components/device_tracker/tomato.py +++ b/homeassistant/components/device_tracker/tomato.py @@ -1,7 +1,7 @@ """ Supports scanning a Tomato router. """ import logging import json -from datetime import datetime, timedelta +from datetime import timedelta import re import threading @@ -55,7 +55,6 @@ class TomatoDeviceScanner(object): self.logger = logging.getLogger("{}.{}".format(__name__, "Tomato")) self.lock = threading.Lock() - self.date_updated = None self.last_results = {"wldev": [], "dhcpd_lease": []} self.success_init = self._update_tomato_info() @@ -71,10 +70,6 @@ class TomatoDeviceScanner(object): def get_device_name(self, device): """ Returns the name of the given device or None if we don't know. """ - # Make sure there are results - if not self.date_updated: - self._update_tomato_info() - filter_named = [item[0] for item in self.last_results['dhcpd_lease'] if item[2] == device] @@ -83,16 +78,12 @@ class TomatoDeviceScanner(object): else: return filter_named[0] + @util.AddCooldown(MIN_TIME_BETWEEN_SCANS) def _update_tomato_info(self): """ Ensures the information from the Tomato router is up to date. Returns boolean if scanning successful. """ - self.lock.acquire() - - # if date_updated is None or the date is too old we scan for new data - if not self.date_updated or \ - datetime.now() - self.date_updated > MIN_TIME_BETWEEN_SCANS: - + with self.lock: self.logger.info("Scanning") try: @@ -111,8 +102,6 @@ class TomatoDeviceScanner(object): self.last_results[param] = \ json.loads(value.replace("'", '"')) - self.date_updated = datetime.now() - return True elif response.status_code == 401: @@ -146,13 +135,3 @@ class TomatoDeviceScanner(object): "Failed to parse response from router") return False - - finally: - self.lock.release() - - else: - # We acquired the lock before the IF check, - # release it before we return True - self.lock.release() - - return True diff --git a/homeassistant/util.py b/homeassistant/util.py index a4b812803d4..40cb463b37e 100644 --- a/homeassistant/util.py +++ b/homeassistant/util.py @@ -12,6 +12,7 @@ import datetime import re import enum import socket +from functools import wraps RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)') @@ -273,6 +274,45 @@ def validate_config(config, items, logger): return not errors_found +class AddCooldown(object): + """ + A method decorator to add a cooldown to a method. + + If you set a cooldown of 5 seconds. Then if you call a method twice the + underlaying method will not be called if the second call was within + 5 seconds of the first. None will be returned instead. + + Makes a last_call attribute available on the wrapped method. + """ + # pylint: disable=too-few-public-methods + + def __init__(self, min_time): + self.min_time = min_time + + def __call__(self, method): + lock = threading.Lock() + + @wraps(method) + def wrapper(*args, **kwargs): + """ + Wrapper that allows wrapped to be called only once per min_time. + """ + with lock: + now = datetime.datetime.now() + last_call = wrapper.last_call + + if last_call is None or now - last_call > self.min_time: + result = method(*args, **kwargs) + wrapper.last_call = now + return result + else: + return None + + wrapper.last_call = None + + return wrapper + + # Reason why I decided to roll my own ThreadPool instead of using # multiprocessing.dummy.pool or even better, use multiprocessing.pool and # not be hurt by the GIL in the cpython interpreter: