This commit is contained in:
Johann Kellerman 2016-09-14 03:17:51 +02:00 committed by Paulus Schoutsen
parent 165362da0c
commit afc527ea55
2 changed files with 16 additions and 21 deletions

View File

@ -98,8 +98,7 @@ def run_information(point_in_time: Optional[datetime]=None):
def setup(hass: HomeAssistant, config: ConfigType) -> bool: def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Setup the recorder.""" """Setup the recorder."""
# pylint: disable=global-statement global _INSTANCE # pylint: disable=global-statement
global _INSTANCE
if _INSTANCE is not None: if _INSTANCE is not None:
_LOGGER.error('Only a single instance allowed.') _LOGGER.error('Only a single instance allowed.')
@ -164,7 +163,6 @@ class Recorder(threading.Thread):
self.hass = hass self.hass = hass
self.purge_days = purge_days self.purge_days = purge_days
self.queue = queue.Queue() # type: Any self.queue = queue.Queue() # type: Any
self.quit_object = object()
self.recording_start = dt_util.utcnow() self.recording_start = dt_util.utcnow()
self.db_url = uri self.db_url = uri
self.db_ready = threading.Event() self.db_ready = threading.Event()
@ -205,12 +203,9 @@ class Recorder(threading.Thread):
while True: while True:
event = self.queue.get() event = self.queue.get()
if event == self.quit_object: if event is None:
self._close_run() self._close_run()
self._close_connection() self._close_connection()
# pylint: disable=global-statement
global _INSTANCE
_INSTANCE = None
self.queue.task_done() self.queue.task_done()
return return
@ -238,8 +233,11 @@ class Recorder(threading.Thread):
def shutdown(self, event): def shutdown(self, event):
"""Tell the recorder to shut down.""" """Tell the recorder to shut down."""
self.queue.put(self.quit_object) global _INSTANCE # pylint: disable=global-statement
self.queue.join() _INSTANCE = None
self.queue.put(None)
self.join()
def block_till_done(self): def block_till_done(self):
"""Block till all events processed.""" """Block till all events processed."""
@ -251,8 +249,7 @@ class Recorder(threading.Thread):
def _setup_connection(self): def _setup_connection(self):
"""Ensure database is ready to fly.""" """Ensure database is ready to fly."""
# pylint: disable=global-statement global Session # pylint: disable=global-statement
global Session
import homeassistant.components.recorder.models as models import homeassistant.components.recorder.models as models
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -275,8 +272,7 @@ class Recorder(threading.Thread):
def _close_connection(self): def _close_connection(self):
"""Close the connection.""" """Close the connection."""
# pylint: disable=global-statement global Session # pylint: disable=global-statement
global Session
self.engine.dispose() self.engine.dispose()
self.engine = None self.engine = None
Session = None Session = None

View File

@ -3,11 +3,10 @@
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
import unittest import unittest
from unittest.mock import patch
from homeassistant.const import MATCH_ALL from homeassistant.const import MATCH_ALL
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.bootstrap import _setup_component
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant
@ -17,19 +16,19 @@ class TestRecorder(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
db_uri = 'sqlite://' db_uri = 'sqlite://' # In memory DB
with patch('homeassistant.core.Config.path', return_value=db_uri): _setup_component(self.hass, recorder.DOMAIN, {
recorder.setup(self.hass, config={ recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
"recorder": {
"db_url": db_uri}})
self.hass.start() self.hass.start()
recorder._INSTANCE.block_till_db_ready() recorder._verify_instance()
self.session = recorder.Session() self.session = recorder.Session()
recorder._INSTANCE.block_till_done() recorder._INSTANCE.block_till_done()
def tearDown(self): # pylint: disable=invalid-name def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started.""" """Stop everything that was started."""
recorder._INSTANCE.shutdown(None)
self.hass.stop() self.hass.stop()
assert recorder._INSTANCE is None
def _add_test_states(self): def _add_test_states(self):
"""Add multiple states to the db for testing.""" """Add multiple states to the db for testing."""