Extended test_init tests to cover all

This commit is contained in:
Paulus Schoutsen 2015-08-04 18:16:10 +02:00
parent df3ee6005a
commit 2075de3d81
2 changed files with 266 additions and 37 deletions

View File

@ -30,9 +30,9 @@ class TestEventHelpers(unittest.TestCase):
def test_track_point_in_time(self): def test_track_point_in_time(self):
""" Test track point in time. """ """ Test track point in time. """
before_birthday = datetime(1985, 7, 9, 12, 0, 0) before_birthday = datetime(1985, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC)
birthday_paulus = datetime(1986, 7, 9, 12, 0, 0) birthday_paulus = datetime(1986, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC)
after_birthday = datetime(1987, 7, 9, 12, 0, 0) after_birthday = datetime(1987, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC)
runs = [] runs = []
@ -52,7 +52,7 @@ class TestEventHelpers(unittest.TestCase):
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertEqual(1, len(runs)) self.assertEqual(1, len(runs))
track_point_in_utc_time( track_point_in_time(
self.hass, lambda x: runs.append(1), birthday_paulus) self.hass, lambda x: runs.append(1), birthday_paulus)
self._send_time_changed(after_birthday) self._send_time_changed(after_birthday)
@ -65,7 +65,7 @@ class TestEventHelpers(unittest.TestCase):
specific_runs = [] specific_runs = []
track_time_change(self.hass, lambda x: wildcard_runs.append(1)) track_time_change(self.hass, lambda x: wildcard_runs.append(1))
track_time_change( track_utc_time_change(
self.hass, lambda x: specific_runs.append(1), second=[0, 30]) self.hass, lambda x: specific_runs.append(1), second=[0, 30])
self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0)) self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0))
@ -84,7 +84,7 @@ class TestEventHelpers(unittest.TestCase):
self.assertEqual(3, len(wildcard_runs)) self.assertEqual(3, len(wildcard_runs))
def test_track_state_change(self): def test_track_state_change(self):
""" Test states.track_change. """ """ Test track_state_change. """
# 2 lists to track how often our callbacks get called # 2 lists to track how often our callbacks get called
specific_runs = [] specific_runs = []
wildcard_runs = [] wildcard_runs = []

View File

@ -8,18 +8,27 @@ Provides tests to verify that Home Assistant core works.
# pylint: disable=too-few-public-methods # pylint: disable=too-few-public-methods
import os import os
import unittest import unittest
import unittest.mock as mock
import time import time
import threading import threading
from datetime import datetime from datetime import datetime
import pytz
import homeassistant as ha import homeassistant as ha
import homeassistant.util.dt as dt_util
from homeassistant.helpers.event import track_state_change
from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
ATTR_FRIENDLY_NAME, TEMP_CELCIUS,
TEMP_FAHRENHEIT)
PST = pytz.timezone('America/Los_Angeles')
class TestHomeAssistant(unittest.TestCase): class TestHomeAssistant(unittest.TestCase):
""" """
Tests the Home Assistant core classes. Tests the Home Assistant core classes.
Currently only includes tests to test cases that do not
get tested in the API integration tests.
""" """
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
@ -36,13 +45,13 @@ class TestHomeAssistant(unittest.TestCase):
# Already stopped after the block till stopped test # Already stopped after the block till stopped test
pass pass
def test_get_config_path(self): def test_start(self):
""" Test get_config_path method. """ calls = []
self.assertEqual(os.path.join(os.getcwd(), "config"), self.hass.bus.listen_once(EVENT_HOMEASSISTANT_START,
self.hass.config.config_dir) lambda event: calls.append(1))
self.hass.start()
self.assertEqual(os.path.join(os.getcwd(), "config", "test.conf"), self.hass.pool.block_till_done()
self.hass.config.path("test.conf")) self.assertEqual(1, len(calls))
def test_block_till_stoped(self): def test_block_till_stoped(self):
""" Test if we can block till stop service is called. """ """ Test if we can block till stop service is called. """
@ -51,28 +60,48 @@ class TestHomeAssistant(unittest.TestCase):
self.assertFalse(blocking_thread.is_alive()) self.assertFalse(blocking_thread.is_alive())
blocking_thread.start() blocking_thread.start()
# Python will now give attention to the other thread
time.sleep(1) # Threads are unpredictable, try 20 times if we're ready
wait_loops = 0
while not blocking_thread.is_alive() and wait_loops < 20:
wait_loops += 1
time.sleep(0.05)
self.assertTrue(blocking_thread.is_alive()) self.assertTrue(blocking_thread.is_alive())
self.hass.services.call(ha.DOMAIN, ha.SERVICE_HOMEASSISTANT_STOP) self.hass.services.call(ha.DOMAIN, ha.SERVICE_HOMEASSISTANT_STOP)
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
# hass.block_till_stopped checks every second if it should quit # Threads are unpredictable, try 20 times if we're ready
# we have to wait worst case 1 second
wait_loops = 0 wait_loops = 0
while blocking_thread.is_alive() and wait_loops < 50: while blocking_thread.is_alive() and wait_loops < 20:
wait_loops += 1 wait_loops += 1
time.sleep(0.1) time.sleep(0.05)
self.assertFalse(blocking_thread.is_alive()) self.assertFalse(blocking_thread.is_alive())
def test_stopping_with_keyboardinterrupt(self):
calls = []
self.hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP,
lambda event: calls.append(1))
def raise_keyboardinterrupt(length):
# We don't want to patch the sleep of the timer.
if length == 1:
raise KeyboardInterrupt
self.hass.start()
with mock.patch('time.sleep', raise_keyboardinterrupt):
self.hass.block_till_stopped()
self.assertEqual(1, len(calls))
def test_track_point_in_time(self): def test_track_point_in_time(self):
""" Test track point in time. """ """ Test track point in time. """
before_birthday = datetime(1985, 7, 9, 12, 0, 0) before_birthday = datetime(1985, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC)
birthday_paulus = datetime(1986, 7, 9, 12, 0, 0) birthday_paulus = datetime(1986, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC)
after_birthday = datetime(1987, 7, 9, 12, 0, 0) after_birthday = datetime(1987, 7, 9, 12, 0, 0, tzinfo=dt_util.UTC)
runs = [] runs = []
@ -92,7 +121,7 @@ class TestHomeAssistant(unittest.TestCase):
self.hass.pool.block_till_done() self.hass.pool.block_till_done()
self.assertEqual(1, len(runs)) self.assertEqual(1, len(runs))
self.hass.track_point_in_utc_time( self.hass.track_point_in_time(
lambda x: runs.append(1), birthday_paulus) lambda x: runs.append(1), birthday_paulus)
self._send_time_changed(after_birthday) self._send_time_changed(after_birthday)
@ -105,7 +134,7 @@ class TestHomeAssistant(unittest.TestCase):
specific_runs = [] specific_runs = []
self.hass.track_time_change(lambda x: wildcard_runs.append(1)) self.hass.track_time_change(lambda x: wildcard_runs.append(1))
self.hass.track_time_change( self.hass.track_utc_time_change(
lambda x: specific_runs.append(1), second=[0, 30]) lambda x: specific_runs.append(1), second=[0, 30])
self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0)) self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0))
@ -130,6 +159,16 @@ class TestHomeAssistant(unittest.TestCase):
class TestEvent(unittest.TestCase): class TestEvent(unittest.TestCase):
""" Test Event class. """ """ Test Event class. """
def test_eq(self):
now = dt_util.utcnow()
data = {'some': 'attr'}
event1, event2 = [
ha.Event('some_type', data, time_fired=now)
for _ in range(2)
]
self.assertEqual(event1, event2)
def test_repr(self): def test_repr(self):
""" Test that repr method works. #MoreCoverage """ """ Test that repr method works. #MoreCoverage """
self.assertEqual( self.assertEqual(
@ -142,13 +181,27 @@ class TestEvent(unittest.TestCase):
{"beer": "nice"}, {"beer": "nice"},
ha.EventOrigin.remote))) ha.EventOrigin.remote)))
def test_as_dict(self):
event_type = 'some_type'
now = dt_util.utcnow()
data = {'some': 'attr'}
event = ha.Event(event_type, data, ha.EventOrigin.local, now)
expected = {
'event_type': event_type,
'data': data,
'origin': 'LOCAL',
'time_fired': dt_util.datetime_to_str(now),
}
self.assertEqual(expected, event.as_dict())
class TestEventBus(unittest.TestCase): class TestEventBus(unittest.TestCase):
""" Test EventBus methods. """ """ Test EventBus methods. """
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
""" things to be run when tests are started. """ """ things to be run when tests are started. """
self.bus = ha.EventBus() self.bus = ha.EventBus(ha.create_worker_pool(0))
self.bus.listen('test_event', lambda x: len) self.bus.listen('test_event', lambda x: len)
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
@ -157,6 +210,7 @@ class TestEventBus(unittest.TestCase):
def test_add_remove_listener(self): def test_add_remove_listener(self):
""" Test remove_listener method. """ """ Test remove_listener method. """
self.bus._pool.add_worker()
old_count = len(self.bus.listeners) old_count = len(self.bus.listeners)
listener = lambda x: len listener = lambda x: len
@ -182,11 +236,10 @@ class TestEventBus(unittest.TestCase):
self.bus.listen_once('test_event', lambda x: runs.append(1)) self.bus.listen_once('test_event', lambda x: runs.append(1))
self.bus.fire('test_event') self.bus.fire('test_event')
self.bus._pool.block_till_done()
self.assertEqual(1, len(runs))
# Second time it should not increase runs # Second time it should not increase runs
self.bus.fire('test_event') self.bus.fire('test_event')
self.bus._pool.add_worker()
self.bus._pool.block_till_done() self.bus._pool.block_till_done()
self.assertEqual(1, len(runs)) self.assertEqual(1, len(runs))
@ -200,6 +253,37 @@ class TestState(unittest.TestCase):
ha.InvalidEntityFormatError, ha.State, ha.InvalidEntityFormatError, ha.State,
'invalid_entity_format', 'test_state') 'invalid_entity_format', 'test_state')
def test_domain(self):
state = ha.State('some_domain.hello', 'world')
self.assertEqual('some_domain', state.domain)
def test_object_id(self):
state = ha.State('domain.hello', 'world')
self.assertEqual('hello', state.object_id)
def test_name_if_no_friendly_name_attr(self):
state = ha.State('domain.hello_world', 'world')
self.assertEqual('hello world', state.name)
def test_name_if_friendly_name_attr(self):
name = 'Some Unique Name'
state = ha.State('domain.hello_world', 'world',
{ATTR_FRIENDLY_NAME: name})
self.assertEqual(name, state.name)
def test_copy(self):
state = ha.State('domain.hello', 'world', {'some': 'attr'})
self.assertEqual(state, state.copy())
def test_dict_conversion(self):
state = ha.State('domain.hello', 'world', {'some': 'attr'})
self.assertEqual(state, ha.State.from_dict(state.as_dict()))
def test_dict_conversion_with_wrong_data(self):
self.assertIsNone(ha.State.from_dict(None))
self.assertIsNone(ha.State.from_dict({'state': 'yes'}))
self.assertIsNone(ha.State.from_dict({'entity_id': 'yes'}))
def test_repr(self): def test_repr(self):
""" Test state.repr """ """ Test state.repr """
self.assertEqual("<state happy.happy=on @ 12:00:00 08-12-1984>", self.assertEqual("<state happy.happy=on @ 12:00:00 08-12-1984>",
@ -218,14 +302,15 @@ class TestStateMachine(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
""" things to be run when tests are started. """ """ things to be run when tests are started. """
self.bus = ha.EventBus() self.pool = ha.create_worker_pool(0)
self.bus = ha.EventBus(self.pool)
self.states = ha.StateMachine(self.bus) self.states = ha.StateMachine(self.bus)
self.states.set("light.Bowl", "on") self.states.set("light.Bowl", "on")
self.states.set("switch.AC", "off") self.states.set("switch.AC", "off")
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
""" Stop down stuff we started. """ """ Stop down stuff we started. """
self.bus._pool.stop() self.pool.stop()
def test_is_state(self): def test_is_state(self):
""" Test is_state method. """ """ Test is_state method. """
@ -244,6 +329,10 @@ class TestStateMachine(unittest.TestCase):
self.assertEqual(1, len(ent_ids)) self.assertEqual(1, len(ent_ids))
self.assertTrue('light.bowl' in ent_ids) self.assertTrue('light.bowl' in ent_ids)
def test_all(self):
states = sorted(state.entity_id for state in self.states.all())
self.assertEqual(['light.bowl', 'switch.ac'], states)
def test_remove(self): def test_remove(self):
""" Test remove method. """ """ Test remove method. """
self.assertTrue('light.bowl' in self.states.entity_ids()) self.assertTrue('light.bowl' in self.states.entity_ids())
@ -255,6 +344,8 @@ class TestStateMachine(unittest.TestCase):
def test_track_change(self): def test_track_change(self):
""" Test states.track_change. """ """ Test states.track_change. """
self.pool.add_worker()
# 2 lists to track how often our callbacks got called # 2 lists to track how often our callbacks got called
specific_runs = [] specific_runs = []
wildcard_runs = [] wildcard_runs = []
@ -291,10 +382,11 @@ class TestStateMachine(unittest.TestCase):
self.assertEqual(3, len(wildcard_runs)) self.assertEqual(3, len(wildcard_runs))
def test_case_insensitivty(self): def test_case_insensitivty(self):
self.pool.add_worker()
runs = [] runs = []
self.states.track_change( track_state_change(
'light.BoWl', lambda a, b, c: runs.append(1), ha._MockHA(self.bus), 'light.BoWl', lambda a, b, c: runs.append(1),
ha.MATCH_ALL, ha.MATCH_ALL) ha.MATCH_ALL, ha.MATCH_ALL)
self.states.set('light.BOWL', 'off') self.states.set('light.BOWL', 'off')
@ -332,16 +424,153 @@ class TestServiceRegistry(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
""" things to be run when tests are started. """ """ things to be run when tests are started. """
self.pool = ha.create_worker_pool() self.pool = ha.create_worker_pool(0)
self.bus = ha.EventBus(self.pool) self.bus = ha.EventBus(self.pool)
self.services = ha.ServiceRegistry(self.bus, self.pool) self.services = ha.ServiceRegistry(self.bus, self.pool)
self.services.register("test_domain", "test_service", lambda x: len) self.services.register("test_domain", "test_service", lambda x: None)
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
""" Stop down stuff we started. """ """ Stop down stuff we started. """
self.pool.stop() if self.pool.worker_count:
self.pool.stop()
def test_has_service(self): def test_has_service(self):
""" Test has_service method. """ """ Test has_service method. """
self.assertTrue( self.assertTrue(
self.services.has_service("test_domain", "test_service")) self.services.has_service("test_domain", "test_service"))
self.assertFalse(
self.services.has_service("test_domain", "non_existing"))
self.assertFalse(
self.services.has_service("non_existing", "test_service"))
def test_services(self):
expected = {
'test_domain': ['test_service']
}
self.assertEqual(expected, self.services.services)
def test_call_with_blocking_done_in_time(self):
self.pool.add_worker()
self.pool.add_worker()
calls = []
self.services.register("test_domain", "register_calls",
lambda x: calls.append(1))
self.assertTrue(
self.services.call('test_domain', 'register_calls', blocking=True))
self.assertEqual(1, len(calls))
def test_call_with_blocking_not_done_in_time(self):
calls = []
self.services.register("test_domain", "register_calls",
lambda x: calls.append(1))
orig_limit = ha.SERVICE_CALL_LIMIT
ha.SERVICE_CALL_LIMIT = 0.01
self.assertFalse(
self.services.call('test_domain', 'register_calls', blocking=True))
self.assertEqual(0, len(calls))
ha.SERVICE_CALL_LIMIT = orig_limit
def test_call_non_existing_with_blocking(self):
self.pool.add_worker()
self.pool.add_worker()
orig_limit = ha.SERVICE_CALL_LIMIT
ha.SERVICE_CALL_LIMIT = 0.01
self.assertFalse(
self.services.call('test_domain', 'i_do_not_exist', blocking=True))
ha.SERVICE_CALL_LIMIT = orig_limit
class TestConfig(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name
""" things to be run when tests are started. """
self.config = ha.Config()
def test_config_dir_set_correct(self):
""" Test config dir set correct. """
self.assertEqual(os.path.join(os.getcwd(), "config"),
self.config.config_dir)
def test_path_with_file(self):
""" Test get_config_path method. """
self.assertEqual(os.path.join(os.getcwd(), "config", "test.conf"),
self.config.path("test.conf"))
def test_path_with_dir_and_file(self):
""" Test get_config_path method. """
self.assertEqual(
os.path.join(os.getcwd(), "config", "dir", "test.conf"),
self.config.path("dir", "test.conf"))
def test_temperature_not_convert_if_no_preference(self):
""" No unit conversion to happen if no preference. """
self.assertEqual(
(25, TEMP_CELCIUS),
self.config.temperature(25, TEMP_CELCIUS))
self.assertEqual(
(80, TEMP_FAHRENHEIT),
self.config.temperature(80, TEMP_FAHRENHEIT))
def test_temperature_not_convert_if_invalid_value(self):
""" No unit conversion to happen if no preference. """
self.config.temperature_unit = TEMP_FAHRENHEIT
self.assertEqual(
('25a', TEMP_CELCIUS),
self.config.temperature('25a', TEMP_CELCIUS))
def test_temperature_not_convert_if_invalid_unit(self):
""" No unit conversion to happen if no preference. """
self.assertEqual(
(25, 'Invalid unit'),
self.config.temperature(25, 'Invalid unit'))
def test_temperature_to_convert_to_celcius(self):
self.config.temperature_unit = TEMP_CELCIUS
self.assertEqual(
(25, TEMP_CELCIUS),
self.config.temperature(25, TEMP_CELCIUS))
self.assertEqual(
(26.7, TEMP_CELCIUS),
self.config.temperature(80, TEMP_FAHRENHEIT))
def test_temperature_to_convert_to_fahrenheit(self):
self.config.temperature_unit = TEMP_FAHRENHEIT
self.assertEqual(
(77, TEMP_FAHRENHEIT),
self.config.temperature(25, TEMP_CELCIUS))
self.assertEqual(
(80, TEMP_FAHRENHEIT),
self.config.temperature(80, TEMP_FAHRENHEIT))
def test_as_dict(self):
expected = {
'latitude': None,
'longitude': None,
'temperature_unit': None,
'location_name': None,
'time_zone': 'UTC',
'components': [],
}
self.assertEqual(expected, self.config.as_dict())
class TestWorkerPool(unittest.TestCase):
def test_exception_during_job(self):
pool = ha.create_worker_pool(1)
def malicious_job(_):
raise Exception("Test breaking worker pool")
calls = []
def register_call(_):
calls.append(1)
pool.add_job(ha.JobPriority.EVENT_DEFAULT, (malicious_job, None))
pool.add_job(ha.JobPriority.EVENT_DEFAULT, (register_call, None))
pool.block_till_done()
self.assertEqual(1, len(calls))