Add connection validation on import for dsmr integration (#39664)

This commit is contained in:
Rob Bierbooms 2020-09-05 12:05:46 +02:00 committed by GitHub
parent 3565fec005
commit 8567fe94e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 283 additions and 10 deletions

View File

@ -1,15 +1,114 @@
"""Config flow for DSMR integration.""" """Config flow for DSMR integration."""
import asyncio
from functools import partial
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from homeassistant import config_entries from async_timeout import timeout
from dsmr_parser import obis_references as obis_ref
from dsmr_parser.clients.protocol import create_dsmr_reader, create_tcp_dsmr_reader
import serial
from homeassistant import config_entries, core, exceptions
from homeassistant.const import CONF_HOST, CONF_PORT from homeassistant.const import CONF_HOST, CONF_PORT
from .const import DOMAIN # pylint:disable=unused-import from .const import ( # pylint:disable=unused-import
CONF_DSMR_VERSION,
CONF_SERIAL_ID,
CONF_SERIAL_ID_GAS,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class DSMRConnection:
"""Test the connection to DSMR and receive telegram to read serial ids."""
def __init__(self, host, port, dsmr_version):
"""Initialize."""
self._host = host
self._port = port
self._dsmr_version = dsmr_version
self._telegram = {}
def equipment_identifier(self):
"""Equipment identifier."""
if obis_ref.EQUIPMENT_IDENTIFIER in self._telegram:
dsmr_object = self._telegram[obis_ref.EQUIPMENT_IDENTIFIER]
return getattr(dsmr_object, "value", None)
def equipment_identifier_gas(self):
"""Equipment identifier gas."""
if obis_ref.EQUIPMENT_IDENTIFIER_GAS in self._telegram:
dsmr_object = self._telegram[obis_ref.EQUIPMENT_IDENTIFIER_GAS]
return getattr(dsmr_object, "value", None)
async def validate_connect(self, hass: core.HomeAssistant) -> bool:
"""Test if we can validate connection with the device."""
def update_telegram(telegram):
self._telegram = telegram
transport.close()
if self._host is None:
reader_factory = partial(
create_dsmr_reader,
self._port,
self._dsmr_version,
update_telegram,
loop=hass.loop,
)
else:
reader_factory = partial(
create_tcp_dsmr_reader,
self._host,
self._port,
self._dsmr_version,
update_telegram,
loop=hass.loop,
)
try:
transport, protocol = await asyncio.create_task(reader_factory())
except (serial.serialutil.SerialException, OSError):
_LOGGER.exception("Error connecting to DSMR")
return False
if transport:
try:
async with timeout(30):
await protocol.wait_closed()
except asyncio.TimeoutError:
# Timeout (no data received), close transport and return True (if telegram is empty, will result in CannotCommunicate error)
transport.close()
await protocol.wait_closed()
return True
async def _validate_dsmr_connection(hass: core.HomeAssistant, data):
"""Validate the user input allows us to connect."""
conn = DSMRConnection(data.get(CONF_HOST), data[CONF_PORT], data[CONF_DSMR_VERSION])
if not await conn.validate_connect(hass):
raise CannotConnect
equipment_identifier = conn.equipment_identifier()
equipment_identifier_gas = conn.equipment_identifier_gas()
# Check only for equipment identifier in case no gas meter is connected
if equipment_identifier is None:
raise CannotCommunicate
info = {
CONF_SERIAL_ID: equipment_identifier,
CONF_SERIAL_ID_GAS: equipment_identifier_gas,
}
return info
class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for DSMR.""" """Handle a config flow for DSMR."""
@ -55,9 +154,29 @@ class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
if status is not None: if status is not None:
return status return status
try:
info = await _validate_dsmr_connection(self.hass, import_config)
except CannotConnect:
return self.async_abort(reason="cannot_connect")
except CannotCommunicate:
return self.async_abort(reason="cannot_communicate")
if host is not None: if host is not None:
name = f"{host}:{port}" name = f"{host}:{port}"
else: else:
name = port name = port
return self.async_create_entry(title=name, data=import_config) data = {**import_config, **info}
await self.async_set_unique_id(info[CONF_SERIAL_ID])
self._abort_if_unique_id_configured(data)
return self.async_create_entry(title=name, data=data)
class CannotConnect(exceptions.HomeAssistantError):
"""Error to indicate we cannot connect."""
class CannotCommunicate(exceptions.HomeAssistantError):
"""Error to indicate we cannot connect."""

View File

@ -8,6 +8,9 @@ CONF_DSMR_VERSION = "dsmr_version"
CONF_RECONNECT_INTERVAL = "reconnect_interval" CONF_RECONNECT_INTERVAL = "reconnect_interval"
CONF_PRECISION = "precision" CONF_PRECISION = "precision"
CONF_SERIAL_ID = "serial_id"
CONF_SERIAL_ID_GAS = "serial_id_gas"
DEFAULT_DSMR_VERSION = "2.2" DEFAULT_DSMR_VERSION = "2.2"
DEFAULT_PORT = "/dev/ttyUSB0" DEFAULT_PORT = "/dev/ttyUSB0"
DEFAULT_PRECISION = 3 DEFAULT_PRECISION = 3

View File

@ -1,12 +1,65 @@
"""Test the DSMR config flow.""" """Test the DSMR config flow."""
import asyncio
from itertools import chain, repeat
from dsmr_parser.clients.protocol import DSMRProtocol
from dsmr_parser.obis_references import EQUIPMENT_IDENTIFIER, EQUIPMENT_IDENTIFIER_GAS
from dsmr_parser.objects import CosemObject
import pytest
import serial
from homeassistant import config_entries, setup from homeassistant import config_entries, setup
from homeassistant.components.dsmr import DOMAIN from homeassistant.components.dsmr import DOMAIN
from tests.async_mock import patch from tests.async_mock import DEFAULT, AsyncMock, Mock, patch
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
SERIAL_DATA = {"serial_id": "12345678", "serial_id_gas": "123456789"}
async def test_import_usb(hass):
@pytest.fixture
def mock_connection_factory(monkeypatch):
"""Mock the create functions for serial and TCP Asyncio connections."""
transport = Mock(spec=asyncio.Transport)
protocol = Mock(spec=DSMRProtocol)
async def connection_factory(*args, **kwargs):
"""Return mocked out Asyncio classes."""
return (transport, protocol)
connection_factory = Mock(wraps=connection_factory)
# apply the mock to both connection factories
monkeypatch.setattr(
"homeassistant.components.dsmr.config_flow.create_dsmr_reader",
connection_factory,
)
monkeypatch.setattr(
"homeassistant.components.dsmr.config_flow.create_tcp_dsmr_reader",
connection_factory,
)
protocol.telegram = {
EQUIPMENT_IDENTIFIER: CosemObject([{"value": "12345678", "unit": ""}]),
EQUIPMENT_IDENTIFIER_GAS: CosemObject([{"value": "123456789", "unit": ""}]),
}
async def wait_closed():
if isinstance(connection_factory.call_args_list[0][0][2], str):
# TCP
telegram_callback = connection_factory.call_args_list[0][0][3]
else:
# Serial
telegram_callback = connection_factory.call_args_list[0][0][2]
telegram_callback(protocol.telegram)
protocol.wait_closed = wait_closed
return connection_factory, transport, protocol
async def test_import_usb(hass, mock_connection_factory):
"""Test we can import.""" """Test we can import."""
await setup.async_setup_component(hass, "persistent_notification", {}) await setup.async_setup_component(hass, "persistent_notification", {})
@ -26,10 +79,103 @@ async def test_import_usb(hass):
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
assert result["title"] == "/dev/ttyUSB0" assert result["title"] == "/dev/ttyUSB0"
assert result["data"] == entry_data assert result["data"] == {**entry_data, **SERIAL_DATA}
async def test_import_network(hass): async def test_import_usb_failed_connection(hass, monkeypatch, mock_connection_factory):
"""Test we can import."""
(connection_factory, transport, protocol) = mock_connection_factory
await setup.async_setup_component(hass, "persistent_notification", {})
entry_data = {
"port": "/dev/ttyUSB0",
"dsmr_version": "2.2",
"precision": 4,
"reconnect_interval": 30,
}
# override the mock to have it fail the first time and succeed after
first_fail_connection_factory = AsyncMock(
return_value=(transport, protocol),
side_effect=chain([serial.serialutil.SerialException], repeat(DEFAULT)),
)
monkeypatch.setattr(
"homeassistant.components.dsmr.config_flow.create_dsmr_reader",
first_fail_connection_factory,
)
with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_IMPORT},
data=entry_data,
)
assert result["type"] == "abort"
assert result["reason"] == "cannot_connect"
async def test_import_usb_no_data(hass, monkeypatch, mock_connection_factory):
"""Test we can import."""
(connection_factory, transport, protocol) = mock_connection_factory
await setup.async_setup_component(hass, "persistent_notification", {})
entry_data = {
"port": "/dev/ttyUSB0",
"dsmr_version": "2.2",
"precision": 4,
"reconnect_interval": 30,
}
# override the mock to have it fail the first time and succeed after
wait_closed = AsyncMock(
return_value=(transport, protocol),
side_effect=chain([asyncio.TimeoutError], repeat(DEFAULT)),
)
protocol.wait_closed = wait_closed
with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_IMPORT},
data=entry_data,
)
assert result["type"] == "abort"
assert result["reason"] == "cannot_communicate"
async def test_import_usb_wrong_telegram(hass, mock_connection_factory):
"""Test we can import."""
(connection_factory, transport, protocol) = mock_connection_factory
await setup.async_setup_component(hass, "persistent_notification", {})
entry_data = {
"port": "/dev/ttyUSB0",
"dsmr_version": "2.2",
"precision": 4,
"reconnect_interval": 30,
}
protocol.telegram = {}
with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_IMPORT},
data=entry_data,
)
assert result["type"] == "abort"
assert result["reason"] == "cannot_communicate"
async def test_import_network(hass, mock_connection_factory):
"""Test we can import from network.""" """Test we can import from network."""
await setup.async_setup_component(hass, "persistent_notification", {}) await setup.async_setup_component(hass, "persistent_notification", {})
@ -50,10 +196,10 @@ async def test_import_network(hass):
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
assert result["title"] == "localhost:1234" assert result["title"] == "localhost:1234"
assert result["data"] == entry_data assert result["data"] == {**entry_data, **SERIAL_DATA}
async def test_import_update(hass): async def test_import_update(hass, mock_connection_factory):
"""Test we can import.""" """Test we can import."""
await setup.async_setup_component(hass, "persistent_notification", {}) await setup.async_setup_component(hass, "persistent_notification", {})

View File

@ -61,8 +61,13 @@ async def test_setup_platform(hass, mock_connection_factory):
"reconnect_interval": 30, "reconnect_interval": 30,
} }
serial_data = {"serial_id": "1234", "serial_id_gas": "5678"}
with patch("homeassistant.components.dsmr.async_setup", return_value=True), patch( with patch("homeassistant.components.dsmr.async_setup", return_value=True), patch(
"homeassistant.components.dsmr.async_setup_entry", return_value=True "homeassistant.components.dsmr.async_setup_entry", return_value=True
), patch(
"homeassistant.components.dsmr.config_flow._validate_dsmr_connection",
return_value=serial_data,
): ):
assert await async_setup_component( assert await async_setup_component(
hass, SENSOR_DOMAIN, {SENSOR_DOMAIN: entry_data} hass, SENSOR_DOMAIN, {SENSOR_DOMAIN: entry_data}
@ -79,7 +84,7 @@ async def test_setup_platform(hass, mock_connection_factory):
entry = conf_entries[0] entry = conf_entries[0]
assert entry.state == "loaded" assert entry.state == "loaded"
assert entry.data == entry_data assert entry.data == {**entry_data, **serial_data}
async def test_default_setup(hass, mock_connection_factory): async def test_default_setup(hass, mock_connection_factory):