This commit is contained in:
Johann Kellerman 2016-09-14 03:17:51 +02:00 committed by Paulus Schoutsen
parent 165362da0c
commit afc527ea55
2 changed files with 16 additions and 21 deletions

View File

@ -98,8 +98,7 @@ def run_information(point_in_time: Optional[datetime]=None):
def setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Setup the recorder."""
# pylint: disable=global-statement
global _INSTANCE
global _INSTANCE # pylint: disable=global-statement
if _INSTANCE is not None:
_LOGGER.error('Only a single instance allowed.')
@ -164,7 +163,6 @@ class Recorder(threading.Thread):
self.hass = hass
self.purge_days = purge_days
self.queue = queue.Queue() # type: Any
self.quit_object = object()
self.recording_start = dt_util.utcnow()
self.db_url = uri
self.db_ready = threading.Event()
@ -205,12 +203,9 @@ class Recorder(threading.Thread):
while True:
event = self.queue.get()
if event == self.quit_object:
if event is None:
self._close_run()
self._close_connection()
# pylint: disable=global-statement
global _INSTANCE
_INSTANCE = None
self.queue.task_done()
return
@ -238,8 +233,11 @@ class Recorder(threading.Thread):
def shutdown(self, event):
"""Tell the recorder to shut down."""
self.queue.put(self.quit_object)
self.queue.join()
global _INSTANCE # pylint: disable=global-statement
_INSTANCE = None
self.queue.put(None)
self.join()
def block_till_done(self):
"""Block till all events processed."""
@ -251,8 +249,7 @@ class Recorder(threading.Thread):
def _setup_connection(self):
"""Ensure database is ready to fly."""
# pylint: disable=global-statement
global Session
global Session # pylint: disable=global-statement
import homeassistant.components.recorder.models as models
from sqlalchemy import create_engine
@ -275,8 +272,7 @@ class Recorder(threading.Thread):
def _close_connection(self):
"""Close the connection."""
# pylint: disable=global-statement
global Session
global Session # pylint: disable=global-statement
self.engine.dispose()
self.engine = None
Session = None

View File

@ -3,11 +3,10 @@
import json
from datetime import datetime, timedelta
import unittest
from unittest.mock import patch
from homeassistant.const import MATCH_ALL
from homeassistant.components import recorder
from homeassistant.bootstrap import _setup_component
from tests.common import get_test_home_assistant
@ -17,19 +16,19 @@ class TestRecorder(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
db_uri = 'sqlite://'
with patch('homeassistant.core.Config.path', return_value=db_uri):
recorder.setup(self.hass, config={
"recorder": {
"db_url": db_uri}})
db_uri = 'sqlite://' # In memory DB
_setup_component(self.hass, recorder.DOMAIN, {
recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
self.hass.start()
recorder._INSTANCE.block_till_db_ready()
recorder._verify_instance()
self.session = recorder.Session()
recorder._INSTANCE.block_till_done()
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
recorder._INSTANCE.shutdown(None)
self.hass.stop()
assert recorder._INSTANCE is None
def _add_test_states(self):
"""Add multiple states to the db for testing."""