Core cleanup: two stage shutdown (#5876)

* Core cleanup: two stage shutdown

* fix spell

* fix

* add async logger to close

* change aiohttp to use CLOSE

* address paulus comments

* Fix tests

* Add unittest
This commit is contained in:
Pascal Vizeli 2017-02-13 06:24:07 +01:00 committed by GitHub
parent 4623d1071e
commit 41849eab06
8 changed files with 97 additions and 107 deletions

View File

@ -16,6 +16,7 @@ import homeassistant.components as core_components
from homeassistant.components import persistent_notification
import homeassistant.config as conf_util
import homeassistant.core as core
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
import homeassistant.loader as loader
import homeassistant.util.package as pkg_util
from homeassistant.util.async import (
@ -386,7 +387,7 @@ def async_from_config_dict(config: Dict[str, Any],
None, conf_util.process_ha_config_upgrade, hass)
if enable_log:
enable_logging(hass, verbose, log_rotate_days)
async_enable_logging(hass, verbose, log_rotate_days)
hass.config.skip_pip = skip_pip
if skip_pip:
@ -488,7 +489,7 @@ def async_from_config_file(config_path: str,
yield from hass.loop.run_in_executor(
None, mount_local_lib_path, config_dir)
enable_logging(hass, verbose, log_rotate_days)
async_enable_logging(hass, verbose, log_rotate_days)
try:
config_dict = yield from hass.loop.run_in_executor(
@ -503,11 +504,12 @@ def async_from_config_file(config_path: str,
return hass
def enable_logging(hass: core.HomeAssistant, verbose: bool=False,
log_rotate_days=None) -> None:
@core.callback
def async_enable_logging(hass: core.HomeAssistant, verbose: bool=False,
log_rotate_days=None) -> None:
"""Setup the logging.
Async friendly.
This method must be run in the event loop.
"""
logging.basicConfig(level=logging.INFO)
fmt = ("%(asctime)s %(levelname)s (%(threadName)s) "
@ -537,10 +539,6 @@ def enable_logging(hass: core.HomeAssistant, verbose: bool=False,
except ImportError:
pass
# AsyncHandler allready exists?
if hass.data.get(core.DATA_ASYNCHANDLER):
return
# Log errors to a file if we have write access to file or config dir
err_log_path = hass.config.path(ERROR_LOG_FILENAME)
err_path_exists = os.path.isfile(err_log_path)
@ -561,7 +559,15 @@ def enable_logging(hass: core.HomeAssistant, verbose: bool=False,
err_handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
async_handler = AsyncHandler(hass.loop, err_handler)
hass.data[core.DATA_ASYNCHANDLER] = async_handler
@asyncio.coroutine
def async_stop_async_handler(event):
"""Cleanup async handler."""
logging.getLogger('').removeHandler(async_handler)
yield from async_handler.async_close(blocking=True)
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, async_stop_async_handler)
logger = logging.getLogger('')
logger.addHandler(async_handler)

View File

@ -156,6 +156,7 @@ CONF_ZONE = 'zone'
# #### EVENTS ####
EVENT_HOMEASSISTANT_START = 'homeassistant_start'
EVENT_HOMEASSISTANT_STOP = 'homeassistant_stop'
EVENT_HOMEASSISTANT_CLOSE = 'homeassistant_close'
EVENT_STATE_CHANGED = 'state_changed'
EVENT_TIME_CHANGED = 'time_changed'
EVENT_CALL_SERVICE = 'call_service'

View File

@ -26,7 +26,7 @@ from homeassistant.const import (
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, MATCH_ALL, __version__)
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, __version__)
from homeassistant.exceptions import (
HomeAssistantError, InvalidEntityFormatError, ShuttingDown)
from homeassistant.util.async import (
@ -53,8 +53,6 @@ ENTITY_ID_PATTERN = re.compile(r"^(\w+)\.(\w+)$")
# Size of a executor pool
EXECUTOR_POOL_SIZE = 10
# AsyncHandler for logging
DATA_ASYNCHANDLER = 'log_asynchandler'
_LOGGER = logging.getLogger(__name__)
@ -279,23 +277,17 @@ class HomeAssistant(object):
This method is a coroutine.
"""
import homeassistant.helpers.aiohttp_client as aiohttp_client
# stage 1
self.state = CoreState.stopping
self.async_track_tasks()
self.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
yield from self.async_block_till_done()
self.executor.shutdown()
# stage 2
self.state = CoreState.not_running
# cleanup connector pool from aiohttp
yield from aiohttp_client.async_cleanup_websession(self)
# cleanup async layer from python logging
if self.data.get(DATA_ASYNCHANDLER):
handler = self.data.pop(DATA_ASYNCHANDLER)
logging.getLogger('').removeHandler(handler)
yield from handler.async_close(blocking=True)
self.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
yield from self.async_block_till_done()
self.executor.shutdown()
self.exit_code = exit_code
self.loop.stop()
@ -397,11 +389,11 @@ class EventBus(object):
self._hass.state == CoreState.stopping:
raise ShuttingDown("Home Assistant is shutting down")
# Copy the list of the current listeners because some listeners
# remove themselves as a listener while being executed which
# causes the iterator to be confused.
get = self._listeners.get
listeners = get(MATCH_ALL, []) + get(event_type, [])
listeners = self._listeners.get(event_type, [])
# EVENT_HOMEASSISTANT_CLOSE should go only to his listeners
if event_type != EVENT_HOMEASSISTANT_CLOSE:
listeners = self._listeners.get(MATCH_ALL, []) + listeners
event = Event(event_type, event_data, origin)

View File

@ -9,7 +9,7 @@ from aiohttp.web_exceptions import HTTPGatewayTimeout
import async_timeout
from homeassistant.core import callback
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.const import __version__
DATA_CONNECTOR = 'aiohttp_connector'
@ -38,6 +38,7 @@ def async_get_clientsession(hass, verify_ssl=True):
connector=connector,
headers={USER_AGENT: SERVER_SOFTWARE}
)
_async_register_clientsession_shutdown(hass, clientsession)
hass.data[key] = clientsession
return hass.data[key]
@ -121,7 +122,7 @@ def _async_register_clientsession_shutdown(hass, clientsession):
clientsession.detach()
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, _async_close_websession)
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession)
@callback
@ -130,37 +131,30 @@ def _async_get_connector(hass, verify_ssl=True):
This method must be run in the event loop.
"""
is_new = False
if verify_ssl:
if DATA_CONNECTOR not in hass.data:
connector = aiohttp.TCPConnector(loop=hass.loop)
hass.data[DATA_CONNECTOR] = connector
is_new = True
else:
connector = hass.data[DATA_CONNECTOR]
else:
if DATA_CONNECTOR_NOTVERIFY not in hass.data:
connector = aiohttp.TCPConnector(loop=hass.loop, verify_ssl=False)
hass.data[DATA_CONNECTOR_NOTVERIFY] = connector
is_new = True
else:
connector = hass.data[DATA_CONNECTOR_NOTVERIFY]
if is_new:
@asyncio.coroutine
def _async_close_connector(event):
"""Close connector pool."""
yield from connector.close()
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_CLOSE, _async_close_connector)
return connector
@asyncio.coroutine
def async_cleanup_websession(hass):
"""Cleanup aiohttp connector pool.
This method is a coroutine.
"""
tasks = []
if DATA_CLIENTSESSION in hass.data:
hass.data[DATA_CLIENTSESSION].detach()
if DATA_CONNECTOR in hass.data:
tasks.append(hass.data[DATA_CONNECTOR].close())
if DATA_CLIENTSESSION_NOTVERIFY in hass.data:
hass.data[DATA_CLIENTSESSION_NOTVERIFY].detach()
if DATA_CONNECTOR_NOTVERIFY in hass.data:
tasks.append(hass.data[DATA_CONNECTOR_NOTVERIFY].close())
if tasks:
yield from asyncio.wait(tasks, loop=hass.loop)

View File

@ -1,7 +1,6 @@
"""The tests for the MQTT eventstream component."""
from collections import namedtuple
import json
import unittest
from unittest.mock import ANY, patch
from homeassistant.bootstrap import setup_component
@ -21,16 +20,15 @@ from tests.common import (
)
class TestMqttEventStream(unittest.TestCase):
class TestMqttEventStream(object):
"""Test the MQTT eventstream module."""
def setUp(self): # pylint: disable=invalid-name
def setup_method(self):
"""Setup things to be run when tests are started."""
super(TestMqttEventStream, self).setUp()
self.hass = get_test_home_assistant()
self.mock_mqtt = mock_mqtt_component(self.hass)
def tearDown(self): # pylint: disable=invalid-name
def teardown_method(self):
"""Stop everything that was started."""
self.hass.stop()
@ -46,24 +44,24 @@ class TestMqttEventStream(unittest.TestCase):
def test_setup_succeeds(self):
""""Test the success of the setup."""
self.assertTrue(self.add_eventstream())
assert self.add_eventstream()
def test_setup_with_pub(self):
""""Test the setup with subscription."""
# Should start off with no listeners for all events
self.assertEqual(self.hass.bus.listeners.get('*'), None)
assert self.hass.bus.listeners.get('*') is None
self.assertTrue(self.add_eventstream(pub_topic='bar'))
assert self.add_eventstream(pub_topic='bar')
self.hass.block_till_done()
# Verify that the event handler has been added as a listener
self.assertEqual(self.hass.bus.listeners.get('*'), 1)
assert self.hass.bus.listeners.get('*') == 1
@patch('homeassistant.components.mqtt.subscribe')
def test_subscribe(self, mock_sub):
""""Test the subscription."""
sub_topic = 'foo'
self.assertTrue(self.add_eventstream(sub_topic=sub_topic))
assert self.add_eventstream(sub_topic=sub_topic)
self.hass.block_till_done()
# Verify that the this entity was subscribed to the topic
@ -79,7 +77,7 @@ class TestMqttEventStream(unittest.TestCase):
mock_utcnow.return_value = now
# Add the eventstream component for publishing events
self.assertTrue(self.add_eventstream(pub_topic=pub_topic))
assert self.add_eventstream(pub_topic=pub_topic)
self.hass.block_till_done()
# Reset the mock because it will have already gotten calls for the
@ -93,7 +91,7 @@ class TestMqttEventStream(unittest.TestCase):
# The order of the JSON is indeterminate,
# so first just check that publish was called
mock_pub.assert_called_with(self.hass, pub_topic, ANY)
self.assertTrue(mock_pub.called)
assert mock_pub.called
# Get the actual call to publish and make sure it was the one
# we were looking for
@ -110,12 +108,12 @@ class TestMqttEventStream(unittest.TestCase):
event['event_data'] = {"new_state": new_state, "entity_id": e_id}
# Verify that the message received was that expected
self.assertEqual(json.loads(msg), event)
assert json.loads(msg) == event
@patch('homeassistant.components.mqtt.publish')
def test_time_event_does_not_send_message(self, mock_pub):
""""Test the sending of a new message if time event."""
self.assertTrue(self.add_eventstream(pub_topic='bar'))
assert self.add_eventstream(pub_topic='bar')
self.hass.block_till_done()
# Reset the mock because it will have already gotten calls for the
@ -123,12 +121,12 @@ class TestMqttEventStream(unittest.TestCase):
mock_pub.reset_mock()
fire_time_changed(self.hass, dt_util.utcnow())
self.assertFalse(mock_pub.called)
assert not mock_pub.called
def test_receiving_remote_event_fires_hass_event(self):
""""Test the receiving of the remotely fired event."""
sub_topic = 'foo'
self.assertTrue(self.add_eventstream(sub_topic=sub_topic))
assert self.add_eventstream(sub_topic=sub_topic)
self.hass.block_till_done()
calls = []
@ -147,7 +145,7 @@ class TestMqttEventStream(unittest.TestCase):
fire_mqtt_message(self.hass, sub_topic, payload)
self.hass.block_till_done()
self.assertEqual(1, len(calls))
assert 1 == len(calls)
@patch('homeassistant.components.mqtt.publish')
def test_mqtt_received_event(self, mock_pub):
@ -160,10 +158,9 @@ class TestMqttEventStream(unittest.TestCase):
"""
SUB_TOPIC = 'from_slaves'
self.assertTrue(
self.add_eventstream(
assert self.add_eventstream(
pub_topic='bar',
sub_topic=SUB_TOPIC))
sub_topic=SUB_TOPIC)
self.hass.block_till_done()
# Reset the mock because it will have already gotten calls for the
@ -173,19 +170,21 @@ class TestMqttEventStream(unittest.TestCase):
# Use MQTT component message handler to simulate firing message
# received event.
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage(SUB_TOPIC, 1, 'Hello World!'.encode('utf-8'))
message = MQTTMessage(
SUB_TOPIC, 1, '{"test": "Hello World!"}'.encode('utf-8'))
mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message)
self.hass.block_till_done()
# 'normal' incoming mqtt messages should be broadcasted
self.assertEqual(mock_pub.call_count, 0)
assert mock_pub.call_count == 0
MQTTMessage = namedtuple('MQTTMessage', ['topic', 'qos', 'payload'])
message = MQTTMessage('test_topic', 1, 'Hello World!'.encode('utf-8'))
message = MQTTMessage(
'test_topic', 1, '{"test": "Hello World!"}'.encode('utf-8'))
mqtt.MQTT._mqtt_on_message(self, None, {'hass': self.hass}, message)
self.hass.block_till_done()
# but event from the event stream not
self.assertEqual(mock_pub.call_count, 1)
assert mock_pub.call_count == 1

View File

@ -4,10 +4,10 @@ import unittest
import aiohttp
from homeassistant.core import EVENT_HOMEASSISTANT_CLOSE
from homeassistant.bootstrap import setup_component
import homeassistant.helpers.aiohttp_client as client
from homeassistant.util.async import (
run_callback_threadsafe, run_coroutine_threadsafe)
from homeassistant.util.async import run_callback_threadsafe
from tests.common import get_test_home_assistant
@ -93,9 +93,8 @@ class TestHelpersAiohttpClient(unittest.TestCase):
assert isinstance(
self.hass.data[client.DATA_CONNECTOR], aiohttp.TCPConnector)
run_coroutine_threadsafe(
client.async_cleanup_websession(self.hass), self.hass.loop
).result()
self.hass.bus.fire(EVENT_HOMEASSISTANT_CLOSE)
self.hass.block_till_done()
assert self.hass.data[client.DATA_CLIENTSESSION].closed
assert self.hass.data[client.DATA_CONNECTOR].closed
@ -112,9 +111,8 @@ class TestHelpersAiohttpClient(unittest.TestCase):
self.hass.data[client.DATA_CONNECTOR_NOTVERIFY],
aiohttp.TCPConnector)
run_coroutine_threadsafe(
client.async_cleanup_websession(self.hass), self.hass.loop
).result()
self.hass.bus.fire(EVENT_HOMEASSISTANT_CLOSE)
self.hass.block_till_done()
assert self.hass.data[client.DATA_CLIENTSESSION_NOTVERIFY].closed
assert self.hass.data[client.DATA_CONNECTOR_NOTVERIFY].closed

View File

@ -71,7 +71,7 @@ class TestBootstrap:
with mock.patch('os.path.isfile', mock.Mock(return_value=True)), \
mock.patch('os.access', mock.Mock(return_value=True)), \
mock.patch('homeassistant.bootstrap.enable_logging',
mock.patch('homeassistant.bootstrap.async_enable_logging',
mock.Mock(return_value=True)), \
patch_yaml_files(files, True):
self.hass = bootstrap.from_config_file('config.yaml')
@ -289,7 +289,7 @@ class TestBootstrap:
assert not bootstrap.setup_component(self.hass, 'comp', {})
assert 'comp' not in self.hass.config.components
@mock.patch('homeassistant.bootstrap.enable_logging')
@mock.patch('homeassistant.bootstrap.async_enable_logging')
@mock.patch('homeassistant.bootstrap.async_register_signal_handling')
def test_home_assistant_core_config_validation(self, log_mock, sig_mock):
"""Test if we pass in wrong information for HA conf."""

View File

@ -16,7 +16,7 @@ from homeassistant.util.unit_system import (METRIC_SYSTEM)
from homeassistant.const import (
__version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM,
ATTR_NOW, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_START)
EVENT_HOMEASSISTANT_CLOSE, EVENT_HOMEASSISTANT_START)
from tests.common import get_test_home_assistant
@ -89,6 +89,26 @@ def test_async_run_job_delegates_non_async():
assert len(hass.async_add_job.mock_calls) == 1
def test_stage_shutdown():
"""Simulate a shutdown, test calling stuff."""
hass = get_test_home_assistant()
test_stop = []
test_close = []
test_all = []
hass.bus.listen(
EVENT_HOMEASSISTANT_STOP, lambda event: test_stop.append(event))
hass.bus.listen(
EVENT_HOMEASSISTANT_CLOSE, lambda event: test_close.append(event))
hass.bus.listen('*', lambda event: test_all.append(event))
hass.stop()
assert len(test_stop) == 1
assert len(test_close) == 1
assert len(test_all) == 1
class TestHomeAssistant(unittest.TestCase):
"""Test the Home Assistant core classes."""
@ -102,26 +122,6 @@ class TestHomeAssistant(unittest.TestCase):
"""Stop everything that was started."""
self.hass.stop()
# This test hangs on `loop.add_signal_handler`
# def test_start_and_sigterm(self):
# """Start the test."""
# calls = []
# self.hass.bus.listen_once(EVENT_HOMEASSISTANT_START,
# lambda event: calls.append(1))
# self.hass.start()
# self.assertEqual(1, len(calls))
# self.hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP,
# lambda event: calls.append(1))
# os.kill(os.getpid(), signal.SIGTERM)
# self.hass.block_till_done()
# self.assertEqual(1, len(calls))
def test_pending_sheduler(self):
"""Add a coro to pending tasks."""
call_count = []