diff --git a/tests/common.py b/tests/common.py index dec5bdac4a4..5915a45a84c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from aiohttp import web from homeassistant import core as ha, loader -from homeassistant.setup import setup_component +from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.entity import ToggleEntity @@ -45,12 +45,27 @@ def threadsafe_callback_factory(func): def threadsafe(*args, **kwargs): """Call func threadsafe.""" hass = args[0] - run_callback_threadsafe( + return run_callback_threadsafe( hass.loop, ft.partial(func, *args, **kwargs)).result() return threadsafe +def threadsafe_coroutine_factory(func): + """Create threadsafe functions out of coroutine. + + Callback needs to have `hass` as first argument. + """ + @ft.wraps(func) + def threadsafe(*args, **kwargs): + """Call func threadsafe.""" + hass = args[0] + return run_coroutine_threadsafe( + func(*args, **kwargs), hass.loop).result() + + return threadsafe + + def get_test_config_dir(*add_path): """Return a path to a test config dir.""" return os.path.join(os.path.dirname(__file__), 'testing_config', *add_path) @@ -258,11 +273,12 @@ def mock_http_component_app(hass, api_password=None): return app -def mock_mqtt_component(hass): +@asyncio.coroutine +def async_mock_mqtt_component(hass): """Mock the MQTT component.""" with patch('homeassistant.components.mqtt.MQTT') as mock_mqtt: mock_mqtt().async_connect.return_value = mock_coro(True) - setup_component(hass, mqtt.DOMAIN, { + yield from async_setup_component(hass, mqtt.DOMAIN, { mqtt.DOMAIN: { mqtt.CONF_BROKER: 'mock-broker', } @@ -270,6 +286,9 @@ def mock_mqtt_component(hass): return mock_mqtt +mock_mqtt_component = threadsafe_coroutine_factory(async_mock_mqtt_component) + + @ha.callback def mock_component(hass, component): """Mock a component is setup.""" diff --git a/tests/components/camera/test_mqtt.py b/tests/components/camera/test_mqtt.py index 802d29a510a..20d15efd982 100644 --- a/tests/components/camera/test_mqtt.py +++ b/tests/components/camera/test_mqtt.py @@ -1,47 +1,31 @@ """The tests for mqtt camera component.""" import asyncio -import unittest from homeassistant.setup import async_setup_component from tests.common import ( - get_test_home_assistant, mock_mqtt_component, get_test_instance_port) - -import requests - -SERVER_PORT = get_test_instance_port() -HTTP_BASE_URL = 'http://127.0.0.1:{}'.format(SERVER_PORT) + async_mock_mqtt_component, async_fire_mqtt_message) -class TestComponentsMQTTCamera(unittest.TestCase): - """Test MQTT camera platform.""" +@asyncio.coroutine +def test_run_camera_setup(hass, test_client): + """Test that it fetches the given payload.""" + topic = 'test/camera' + yield from async_mock_mqtt_component(hass) + yield from async_setup_component(hass, 'camera', { + 'camera': { + 'platform': 'mqtt', + 'topic': topic, + 'name': 'Test Camera', + }}) - def setUp(self): # pylint: disable=invalid-name - """Setup things to be run when tests are started.""" - self.hass = get_test_home_assistant() - self.mock_mqtt = mock_mqtt_component(self.hass) + url = hass.states.get('camera.test_camera').attributes['entity_picture'] - def tearDown(self): # pylint: disable=invalid-name - """Stop everything that was started.""" - self.hass.stop() + async_fire_mqtt_message(hass, topic, 'beer') + yield from hass.async_block_till_done() - @asyncio.coroutine - def test_run_camera_setup(self): - """Test that it fetches the given payload.""" - topic = 'test/camera' - yield from async_setup_component(self.hass, 'camera', { - 'camera': { - 'platform': 'mqtt', - 'topic': topic, - 'name': 'Test Camera', - }}) - - self.mock_mqtt.publish(self.hass, topic, 0xFFD8FF) - yield from self.hass.async_block_till_done() - - resp = requests.get(HTTP_BASE_URL + - '/api/camera_proxy/camera.test_camera') - - assert resp.status_code == 200 - body = yield from resp.text - assert body == '16767231' + client = yield from test_client(hass.http.app) + resp = yield from client.get(url) + assert resp.status == 200 + body = yield from resp.text() + assert body == 'beer'