Added initial version of AddCooldown decorator

This commit is contained in:
Paulus Schoutsen 2014-12-04 01:14:27 -08:00
parent eef4817804
commit 31b9f65513
5 changed files with 90 additions and 61 deletions

View File

@ -6,7 +6,8 @@ Tests Home Assistant util methods.
""" """
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
import unittest import unittest
from datetime import datetime import time
from datetime import datetime, timedelta
import homeassistant.util as util import homeassistant.util as util
@ -190,3 +191,33 @@ class TestUtil(unittest.TestCase):
set1.update([1, 2], [5, 6]) set1.update([1, 2], [5, 6])
self.assertEqual([2, 3, 1, 5, 6], set1) self.assertEqual([2, 3, 1, 5, 6], set1)
def test_add_cooldown(self):
""" Test the add cooldown decorator. """
calls = []
@util.AddCooldown(timedelta(milliseconds=500))
def test_cooldown():
calls.append(1)
self.assertEqual(0, len(calls))
test_cooldown()
self.assertEqual(1, len(calls))
test_cooldown()
self.assertEqual(1, len(calls))
time.sleep(.3)
test_cooldown()
self.assertEqual(1, len(calls))
time.sleep(.2)
test_cooldown()
self.assertEqual(2, len(calls))

View File

@ -1,7 +1,7 @@
""" Supports scanning a OpenWRT router. """ """ Supports scanning a OpenWRT router. """
import logging import logging
import json import json
from datetime import datetime, timedelta from datetime import timedelta
import re import re
import threading import threading
import requests import requests
@ -52,7 +52,6 @@ class LuciDeviceScanner(object):
self.lock = threading.Lock() self.lock = threading.Lock()
self.date_updated = None
self.last_results = {} self.last_results = {}
self.token = _get_token(host, username, password) self.token = _get_token(host, username, password)
@ -88,29 +87,25 @@ class LuciDeviceScanner(object):
return return
return self.mac2name.get(device, None) return self.mac2name.get(device, None)
@util.AddCooldown(MIN_TIME_BETWEEN_SCANS)
def _update_info(self): def _update_info(self):
""" Ensures the information from the Luci router is up to date. """ Ensures the information from the Luci router is up to date.
Returns boolean if scanning successful. """ Returns boolean if scanning successful. """
if not self.success_init: if not self.success_init:
return False return False
with self.lock: with self.lock:
# if date_updated is None or the date is too old we scan _LOGGER.info("Checking ARP")
# for new data
if not self.date_updated or \
datetime.now() - self.date_updated > MIN_TIME_BETWEEN_SCANS:
_LOGGER.info("Checking ARP") url = 'http://{}/cgi-bin/luci/rpc/sys'.format(self.host)
result = _req_json_rpc(url, 'net.arptable',
params={'auth': self.token})
if result:
self.last_results = [x['HW address'] for x in result]
url = 'http://{}/cgi-bin/luci/rpc/sys'.format(self.host) return True
result = _req_json_rpc(url, 'net.arptable',
params={'auth': self.token})
if result:
self.last_results = [x['HW address'] for x in result]
self.date_updated = datetime.now()
return True
return False
return True return False
def _req_json_rpc(url, method, *args, **kwargs): def _req_json_rpc(url, method, *args, **kwargs):

View File

@ -1,6 +1,6 @@
""" Supports scanning a Netgear router. """ """ Supports scanning a Netgear router. """
import logging import logging
from datetime import datetime, timedelta from datetime import timedelta
import threading import threading
import homeassistant as ha import homeassistant as ha
@ -34,7 +34,6 @@ class NetgearDeviceScanner(object):
host = config[ha.CONF_HOST] host = config[ha.CONF_HOST]
username, password = config[ha.CONF_USERNAME], config[ha.CONF_PASSWORD] username, password = config[ha.CONF_USERNAME], config[ha.CONF_PASSWORD]
self.date_updated = None
self.last_results = [] self.last_results = []
try: try:
@ -75,10 +74,6 @@ class NetgearDeviceScanner(object):
def get_device_name(self, mac): def get_device_name(self, mac):
""" Returns the name of the given device or None if we don't know. """ """ Returns the name of the given device or None if we don't know. """
# Make sure there are results
if not self.date_updated:
self._update_info()
filter_named = [device.name for device in self.last_results filter_named = [device.name for device in self.last_results
if device.mac == mac] if device.mac == mac]
@ -87,6 +82,7 @@ class NetgearDeviceScanner(object):
else: else:
return None return None
@util.AddCooldown(MIN_TIME_BETWEEN_SCANS)
def _update_info(self): def _update_info(self):
""" Retrieves latest information from the Netgear router. """ Retrieves latest information from the Netgear router.
Returns boolean if scanning successful. """ Returns boolean if scanning successful. """
@ -94,18 +90,6 @@ class NetgearDeviceScanner(object):
return return
with self.lock: with self.lock:
# if date_updated is None or the date is too old we scan for _LOGGER.info("Scanning")
# new data
if not self.date_updated or \
datetime.now() - self.date_updated > MIN_TIME_BETWEEN_SCANS:
_LOGGER.info("Scanning") self.last_results = self._api.get_attached_devices()
self.last_results = self._api.get_attached_devices()
self.date_updated = datetime.now()
return
else:
return

View File

@ -1,7 +1,7 @@
""" Supports scanning a Tomato router. """ """ Supports scanning a Tomato router. """
import logging import logging
import json import json
from datetime import datetime, timedelta from datetime import timedelta
import re import re
import threading import threading
@ -55,7 +55,6 @@ class TomatoDeviceScanner(object):
self.logger = logging.getLogger("{}.{}".format(__name__, "Tomato")) self.logger = logging.getLogger("{}.{}".format(__name__, "Tomato"))
self.lock = threading.Lock() self.lock = threading.Lock()
self.date_updated = None
self.last_results = {"wldev": [], "dhcpd_lease": []} self.last_results = {"wldev": [], "dhcpd_lease": []}
self.success_init = self._update_tomato_info() self.success_init = self._update_tomato_info()
@ -71,10 +70,6 @@ class TomatoDeviceScanner(object):
def get_device_name(self, device): def get_device_name(self, device):
""" Returns the name of the given device or None if we don't know. """ """ Returns the name of the given device or None if we don't know. """
# Make sure there are results
if not self.date_updated:
self._update_tomato_info()
filter_named = [item[0] for item in self.last_results['dhcpd_lease'] filter_named = [item[0] for item in self.last_results['dhcpd_lease']
if item[2] == device] if item[2] == device]
@ -83,16 +78,12 @@ class TomatoDeviceScanner(object):
else: else:
return filter_named[0] return filter_named[0]
@util.AddCooldown(MIN_TIME_BETWEEN_SCANS)
def _update_tomato_info(self): def _update_tomato_info(self):
""" Ensures the information from the Tomato router is up to date. """ Ensures the information from the Tomato router is up to date.
Returns boolean if scanning successful. """ Returns boolean if scanning successful. """
self.lock.acquire() with self.lock:
# if date_updated is None or the date is too old we scan for new data
if not self.date_updated or \
datetime.now() - self.date_updated > MIN_TIME_BETWEEN_SCANS:
self.logger.info("Scanning") self.logger.info("Scanning")
try: try:
@ -111,8 +102,6 @@ class TomatoDeviceScanner(object):
self.last_results[param] = \ self.last_results[param] = \
json.loads(value.replace("'", '"')) json.loads(value.replace("'", '"'))
self.date_updated = datetime.now()
return True return True
elif response.status_code == 401: elif response.status_code == 401:
@ -146,13 +135,3 @@ class TomatoDeviceScanner(object):
"Failed to parse response from router") "Failed to parse response from router")
return False return False
finally:
self.lock.release()
else:
# We acquired the lock before the IF check,
# release it before we return True
self.lock.release()
return True

View File

@ -12,6 +12,7 @@ import datetime
import re import re
import enum import enum
import socket import socket
from functools import wraps
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)') RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)') RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
@ -273,6 +274,45 @@ def validate_config(config, items, logger):
return not errors_found return not errors_found
class AddCooldown(object):
"""
A method decorator to add a cooldown to a method.
If you set a cooldown of 5 seconds. Then if you call a method twice the
underlaying method will not be called if the second call was within
5 seconds of the first. None will be returned instead.
Makes a last_call attribute available on the wrapped method.
"""
# pylint: disable=too-few-public-methods
def __init__(self, min_time):
self.min_time = min_time
def __call__(self, method):
lock = threading.Lock()
@wraps(method)
def wrapper(*args, **kwargs):
"""
Wrapper that allows wrapped to be called only once per min_time.
"""
with lock:
now = datetime.datetime.now()
last_call = wrapper.last_call
if last_call is None or now - last_call > self.min_time:
result = method(*args, **kwargs)
wrapper.last_call = now
return result
else:
return None
wrapper.last_call = None
return wrapper
# Reason why I decided to roll my own ThreadPool instead of using # Reason why I decided to roll my own ThreadPool instead of using
# multiprocessing.dummy.pool or even better, use multiprocessing.pool and # multiprocessing.dummy.pool or even better, use multiprocessing.pool and
# not be hurt by the GIL in the cpython interpreter: # not be hurt by the GIL in the cpython interpreter: