diff --git a/homeassistant/components/tcp/sensor.py b/homeassistant/components/tcp/sensor.py index 54cf4d120f1..ff436f8ecaf 100644 --- a/homeassistant/components/tcp/sensor.py +++ b/homeassistant/components/tcp/sensor.py @@ -2,6 +2,7 @@ import logging import select import socket +import ssl import voluptuous as vol @@ -11,9 +12,11 @@ from homeassistant.const import ( CONF_NAME, CONF_PAYLOAD, CONF_PORT, + CONF_SSL, CONF_TIMEOUT, CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, + CONF_VERIFY_SSL, ) from homeassistant.exceptions import TemplateError import homeassistant.helpers.config_validation as cv @@ -26,6 +29,8 @@ CONF_VALUE_ON = "value_on" DEFAULT_BUFFER_SIZE = 1024 DEFAULT_NAME = "TCP Sensor" DEFAULT_TIMEOUT = 10 +DEFAULT_SSL = False +DEFAULT_VERIFY_SSL = True PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( { @@ -38,6 +43,8 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_VALUE_ON): cv.string, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, + vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean, + vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, } ) @@ -71,6 +78,15 @@ class TcpSensor(SensorEntity): CONF_VALUE_ON: config.get(CONF_VALUE_ON), CONF_BUFFER_SIZE: config.get(CONF_BUFFER_SIZE), } + + if config[CONF_SSL]: + self._ssl_context = ssl.create_default_context() + if not config[CONF_VERIFY_SSL]: + self._ssl_context.check_hostname = False + self._ssl_context.verify_mode = ssl.CERT_NONE + else: + self._ssl_context = None + self._state = None self.update() @@ -104,6 +120,11 @@ class TcpSensor(SensorEntity): ) return + if self._ssl_context is not None: + sock = self._ssl_context.wrap_socket( + sock, server_hostname=self._config[CONF_HOST] + ) + try: sock.send(self._config[CONF_PAYLOAD].encode()) except OSError as err: diff --git a/tests/components/tcp/test_sensor.py b/tests/components/tcp/test_sensor.py index b1efef305bf..48b5703c204 100644 --- a/tests/components/tcp/test_sensor.py +++ b/tests/components/tcp/test_sensor.py @@ -57,6 +57,18 @@ def mock_select_fixture(): yield mock_select +@pytest.fixture(name="mock_ssl_context") +def mock_ssl_context_fixture(): + """Mock select.""" + with patch( + "homeassistant.components.tcp.sensor.ssl.create_default_context", + ) as mock_ssl_context: + mock_ssl_context.return_value.wrap_socket.return_value.recv.return_value = ( + socket_test_value + "_ssl" + ).encode() + yield mock_ssl_context + + async def test_setup_platform_valid_config(hass, mock_socket): """Check a valid configuration and call add_entities with sensor.""" with assert_setup_component(1, "sensor"): @@ -159,3 +171,66 @@ async def test_update_returns_if_template_render_fails(hass, mock_socket): assert state assert state.state == "unknown" + + +async def test_ssl_state(hass, mock_socket, mock_select, mock_ssl_context): + """Return the contents of _state, updated over SSL.""" + config = copy(SENSOR_TEST_CONFIG) + config[tcp.CONF_SSL] = "on" + + assert await async_setup_component(hass, "sensor", {"sensor": config}) + await hass.async_block_till_done() + + state = hass.states.get(TEST_ENTITY) + + assert state + assert state.state == "test_value_ssl" + assert mock_socket.connect.called + assert mock_socket.connect.call_args == call( + (SENSOR_TEST_CONFIG["host"], SENSOR_TEST_CONFIG["port"]) + ) + assert not mock_socket.send.called + assert mock_ssl_context.called + assert mock_ssl_context.return_value.check_hostname + mock_ssl_socket = mock_ssl_context.return_value.wrap_socket.return_value + assert mock_ssl_socket.send.called + assert mock_ssl_socket.send.call_args == call( + SENSOR_TEST_CONFIG["payload"].encode() + ) + assert mock_select.call_args == call( + [mock_ssl_socket], [], [], SENSOR_TEST_CONFIG[tcp.CONF_TIMEOUT] + ) + assert mock_ssl_socket.recv.called + assert mock_ssl_socket.recv.call_args == call(SENSOR_TEST_CONFIG["buffer_size"]) + + +async def test_ssl_state_verify_off(hass, mock_socket, mock_select, mock_ssl_context): + """Return the contents of _state, updated over SSL (verify_ssl disabled).""" + config = copy(SENSOR_TEST_CONFIG) + config[tcp.CONF_SSL] = "on" + config[tcp.CONF_VERIFY_SSL] = "off" + + assert await async_setup_component(hass, "sensor", {"sensor": config}) + await hass.async_block_till_done() + + state = hass.states.get(TEST_ENTITY) + + assert state + assert state.state == "test_value_ssl" + assert mock_socket.connect.called + assert mock_socket.connect.call_args == call( + (SENSOR_TEST_CONFIG["host"], SENSOR_TEST_CONFIG["port"]) + ) + assert not mock_socket.send.called + assert mock_ssl_context.called + assert not mock_ssl_context.return_value.check_hostname + mock_ssl_socket = mock_ssl_context.return_value.wrap_socket.return_value + assert mock_ssl_socket.send.called + assert mock_ssl_socket.send.call_args == call( + SENSOR_TEST_CONFIG["payload"].encode() + ) + assert mock_select.call_args == call( + [mock_ssl_socket], [], [], SENSOR_TEST_CONFIG[tcp.CONF_TIMEOUT] + ) + assert mock_ssl_socket.recv.called + assert mock_ssl_socket.recv.call_args == call(SENSOR_TEST_CONFIG["buffer_size"])