diff --git a/homeassistant/components/shell_command.py b/homeassistant/components/shell_command.py index 6aabdc8ddf7..ca33666d1f3 100644 --- a/homeassistant/components/shell_command.py +++ b/homeassistant/components/shell_command.py @@ -4,15 +4,17 @@ Exposes regular shell commands as services. For more details about this platform, please refer to the documentation at https://home-assistant.io/components/shell_command/ """ +import asyncio import logging -import subprocess import shlex import voluptuous as vol -from homeassistant.helpers import template from homeassistant.exceptions import TemplateError -import homeassistant.helpers.config_validation as cv +from homeassistant.core import ServiceCall +from homeassistant.helpers import config_validation as cv, template +from homeassistant.helpers.typing import ConfigType, HomeAssistantType + DOMAIN = 'shell_command' @@ -25,15 +27,17 @@ CONFIG_SCHEMA = vol.Schema({ }, extra=vol.ALLOW_EXTRA) -def setup(hass, config): +@asyncio.coroutine +def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool: """Set up the shell_command component.""" conf = config.get(DOMAIN, {}) cache = {} - def service_handler(call): + @asyncio.coroutine + def async_service_handler(service: ServiceCall) -> None: """Execute a shell command service.""" - cmd = conf[call.service] + cmd = conf[service.service] if cmd in cache: prog, args, args_compiled = cache[cmd] @@ -49,7 +53,7 @@ def setup(hass, config): if args_compiled: try: - rendered_args = args_compiled.render(call.data) + rendered_args = args_compiled.async_render(service.data) except TemplateError as ex: _LOGGER.exception("Error rendering command template: %s", ex) return @@ -58,19 +62,34 @@ def setup(hass, config): if rendered_args == args: # No template used. default behavior - shell = True - else: - # Template used. Break into list and use shell=False for security - cmd = [prog] + shlex.split(rendered_args) - shell = False - try: - subprocess.call(cmd, shell=shell, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL) - except subprocess.SubprocessError: - _LOGGER.exception("Error running command: %s", cmd) + # pylint: disable=no-member + create_process = asyncio.subprocess.create_subprocess_shell( + cmd, + loop=hass.loop, + stdin=None, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL) + else: + # Template used. Break into list and use create_subprocess_exec + # (which uses shell=False) for security + shlexed_cmd = [prog] + shlex.split(rendered_args) + + # pylint: disable=no-member + create_process = asyncio.subprocess.create_subprocess_exec( + *shlexed_cmd, + loop=hass.loop, + stdin=None, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.DEVNULL) + + process = yield from create_process + yield from process.communicate() + + if process.returncode != 0: + _LOGGER.exception("Error running command: `%s`, return code: %s", + cmd, process.returncode) for name in conf.keys(): - hass.services.register(DOMAIN, name, service_handler) + hass.services.async_register(DOMAIN, name, async_service_handler) return True diff --git a/tests/components/test_shell_command.py b/tests/components/test_shell_command.py index b75a95e23cd..3bdb6896394 100644 --- a/tests/components/test_shell_command.py +++ b/tests/components/test_shell_command.py @@ -1,9 +1,10 @@ """The tests for the Shell command component.""" +import asyncio import os import tempfile import unittest -from unittest.mock import patch -from subprocess import SubprocessError +from typing import Tuple +from unittest.mock import Mock, patch from homeassistant.setup import setup_component from homeassistant.components import shell_command @@ -11,12 +12,35 @@ from homeassistant.components import shell_command from tests.common import get_test_home_assistant +@asyncio.coroutine +def mock_process_creator(error: bool = False) -> asyncio.coroutine: + """Mock a coroutine that creates a process when yielded.""" + @asyncio.coroutine + def communicate() -> Tuple[bytes, bytes]: + """Mock a coroutine that runs a process when yielded. + + Returns: + a tuple of (stdout, stderr). + """ + return b"I am stdout", b"I am stderr" + + mock_process = Mock() + mock_process.communicate = communicate + mock_process.returncode = int(error) + return mock_process + + class TestShellCommand(unittest.TestCase): - """Test the Shell command component.""" + """Test the shell_command component.""" def setUp(self): # pylint: disable=invalid-name - """Setup things to be run when tests are started.""" + """Setup things to be run when tests are started. + + Also seems to require a child watcher attached to the loop when run + from pytest. + """ self.hass = get_test_home_assistant() + asyncio.get_child_watcher().attach_loop(self.hass.loop) def tearDown(self): # pylint: disable=invalid-name """Stop everything that was started.""" @@ -26,84 +50,101 @@ class TestShellCommand(unittest.TestCase): """Test if able to call a configured service.""" with tempfile.TemporaryDirectory() as tempdirname: path = os.path.join(tempdirname, 'called.txt') - assert setup_component(self.hass, shell_command.DOMAIN, { - shell_command.DOMAIN: { - 'test_service': "date > {}".format(path) - } - }) + assert setup_component( + self.hass, + shell_command.DOMAIN, { + shell_command.DOMAIN: { + 'test_service': "date > {}".format(path) + } + } + ) self.hass.services.call('shell_command', 'test_service', blocking=True) self.hass.block_till_done() - self.assertTrue(os.path.isfile(path)) def test_config_not_dict(self): - """Test if config is not a dict.""" - assert not setup_component(self.hass, shell_command.DOMAIN, { - shell_command.DOMAIN: ['some', 'weird', 'list'] - }) + """Test that setup fails if config is not a dict.""" + self.assertFalse( + setup_component(self.hass, shell_command.DOMAIN, { + shell_command.DOMAIN: ['some', 'weird', 'list'] + })) def test_config_not_valid_service_names(self): - """Test if config contains invalid service names.""" - assert not setup_component(self.hass, shell_command.DOMAIN, { - shell_command.DOMAIN: { - 'this is invalid because space': 'touch bla.txt' - } - }) + """Test that setup fails if config contains invalid service names.""" + self.assertFalse( + setup_component(self.hass, shell_command.DOMAIN, { + shell_command.DOMAIN: { + 'this is invalid because space': 'touch bla.txt' + } + })) - @patch('homeassistant.components.shell_command.subprocess.call') + @patch('homeassistant.components.shell_command.asyncio.subprocess' + '.create_subprocess_shell') def test_template_render_no_template(self, mock_call): """Ensure shell_commands without templates get rendered properly.""" - assert setup_component(self.hass, shell_command.DOMAIN, { - shell_command.DOMAIN: { - 'test_service': "ls /bin" - } - }) + mock_call.return_value = mock_process_creator(error=False) + + self.assertTrue( + setup_component( + self.hass, + shell_command.DOMAIN, { + shell_command.DOMAIN: { + 'test_service': "ls /bin" + } + })) self.hass.services.call('shell_command', 'test_service', blocking=True) + self.hass.block_till_done() cmd = mock_call.mock_calls[0][1][0] - shell = mock_call.mock_calls[0][2]['shell'] - assert 'ls /bin' == cmd - assert shell + self.assertEqual(1, mock_call.call_count) + self.assertEqual('ls /bin', cmd) - @patch('homeassistant.components.shell_command.subprocess.call') + @patch('homeassistant.components.shell_command.asyncio.subprocess' + '.create_subprocess_exec') def test_template_render(self, mock_call): - """Ensure shell_commands without templates get rendered properly.""" + """Ensure shell_commands with templates get rendered properly.""" self.hass.states.set('sensor.test_state', 'Works') - assert setup_component(self.hass, shell_command.DOMAIN, { - shell_command.DOMAIN: { - 'test_service': "ls /bin {{ states.sensor.test_state.state }}" - } - }) + self.assertTrue( + setup_component(self.hass, shell_command.DOMAIN, { + shell_command.DOMAIN: { + 'test_service': ("ls /bin {{ states.sensor" + ".test_state.state }}") + } + })) self.hass.services.call('shell_command', 'test_service', blocking=True) - cmd = mock_call.mock_calls[0][1][0] - shell = mock_call.mock_calls[0][2]['shell'] + self.hass.block_till_done() + cmd = mock_call.mock_calls[0][1] - assert ['ls', '/bin', 'Works'] == cmd - assert not shell + self.assertEqual(1, mock_call.call_count) + self.assertEqual(('ls', '/bin', 'Works'), cmd) - @patch('homeassistant.components.shell_command.subprocess.call', - side_effect=SubprocessError) + @patch('homeassistant.components.shell_command.asyncio.subprocess' + '.create_subprocess_shell') @patch('homeassistant.components.shell_command._LOGGER.error') - def test_subprocess_raising_error(self, mock_call, mock_error): - """Test subprocess.""" + def test_subprocess_error(self, mock_error, mock_call): + """Test subprocess that returns an error.""" + mock_call.return_value = mock_process_creator(error=True) with tempfile.TemporaryDirectory() as tempdirname: path = os.path.join(tempdirname, 'called.txt') - assert setup_component(self.hass, shell_command.DOMAIN, { - shell_command.DOMAIN: { - 'test_service': "touch {}".format(path) - } - }) + self.assertTrue( + setup_component(self.hass, shell_command.DOMAIN, { + shell_command.DOMAIN: { + 'test_service': "touch {}".format(path) + } + })) self.hass.services.call('shell_command', 'test_service', blocking=True) - self.assertFalse(os.path.isfile(path)) + self.hass.block_till_done() + self.assertEqual(1, mock_call.call_count) self.assertEqual(1, mock_error.call_count) + self.assertFalse(os.path.isfile(path))