mirror of
https://github.com/home-assistant/core.git
synced 2025-04-26 18:27:51 +00:00
Add connection validation on import for dsmr integration (#39664)
This commit is contained in:
parent
3565fec005
commit
8567fe94e1
@ -1,15 +1,114 @@
|
|||||||
"""Config flow for DSMR integration."""
|
"""Config flow for DSMR integration."""
|
||||||
|
import asyncio
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from async_timeout import timeout
|
||||||
|
from dsmr_parser import obis_references as obis_ref
|
||||||
|
from dsmr_parser.clients.protocol import create_dsmr_reader, create_tcp_dsmr_reader
|
||||||
|
import serial
|
||||||
|
|
||||||
|
from homeassistant import config_entries, core, exceptions
|
||||||
from homeassistant.const import CONF_HOST, CONF_PORT
|
from homeassistant.const import CONF_HOST, CONF_PORT
|
||||||
|
|
||||||
from .const import DOMAIN # pylint:disable=unused-import
|
from .const import ( # pylint:disable=unused-import
|
||||||
|
CONF_DSMR_VERSION,
|
||||||
|
CONF_SERIAL_ID,
|
||||||
|
CONF_SERIAL_ID_GAS,
|
||||||
|
DOMAIN,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DSMRConnection:
|
||||||
|
"""Test the connection to DSMR and receive telegram to read serial ids."""
|
||||||
|
|
||||||
|
def __init__(self, host, port, dsmr_version):
|
||||||
|
"""Initialize."""
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._dsmr_version = dsmr_version
|
||||||
|
self._telegram = {}
|
||||||
|
|
||||||
|
def equipment_identifier(self):
|
||||||
|
"""Equipment identifier."""
|
||||||
|
if obis_ref.EQUIPMENT_IDENTIFIER in self._telegram:
|
||||||
|
dsmr_object = self._telegram[obis_ref.EQUIPMENT_IDENTIFIER]
|
||||||
|
return getattr(dsmr_object, "value", None)
|
||||||
|
|
||||||
|
def equipment_identifier_gas(self):
|
||||||
|
"""Equipment identifier gas."""
|
||||||
|
if obis_ref.EQUIPMENT_IDENTIFIER_GAS in self._telegram:
|
||||||
|
dsmr_object = self._telegram[obis_ref.EQUIPMENT_IDENTIFIER_GAS]
|
||||||
|
return getattr(dsmr_object, "value", None)
|
||||||
|
|
||||||
|
async def validate_connect(self, hass: core.HomeAssistant) -> bool:
|
||||||
|
"""Test if we can validate connection with the device."""
|
||||||
|
|
||||||
|
def update_telegram(telegram):
|
||||||
|
self._telegram = telegram
|
||||||
|
|
||||||
|
transport.close()
|
||||||
|
|
||||||
|
if self._host is None:
|
||||||
|
reader_factory = partial(
|
||||||
|
create_dsmr_reader,
|
||||||
|
self._port,
|
||||||
|
self._dsmr_version,
|
||||||
|
update_telegram,
|
||||||
|
loop=hass.loop,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reader_factory = partial(
|
||||||
|
create_tcp_dsmr_reader,
|
||||||
|
self._host,
|
||||||
|
self._port,
|
||||||
|
self._dsmr_version,
|
||||||
|
update_telegram,
|
||||||
|
loop=hass.loop,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
transport, protocol = await asyncio.create_task(reader_factory())
|
||||||
|
except (serial.serialutil.SerialException, OSError):
|
||||||
|
_LOGGER.exception("Error connecting to DSMR")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if transport:
|
||||||
|
try:
|
||||||
|
async with timeout(30):
|
||||||
|
await protocol.wait_closed()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Timeout (no data received), close transport and return True (if telegram is empty, will result in CannotCommunicate error)
|
||||||
|
transport.close()
|
||||||
|
await protocol.wait_closed()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_dsmr_connection(hass: core.HomeAssistant, data):
|
||||||
|
"""Validate the user input allows us to connect."""
|
||||||
|
conn = DSMRConnection(data.get(CONF_HOST), data[CONF_PORT], data[CONF_DSMR_VERSION])
|
||||||
|
|
||||||
|
if not await conn.validate_connect(hass):
|
||||||
|
raise CannotConnect
|
||||||
|
|
||||||
|
equipment_identifier = conn.equipment_identifier()
|
||||||
|
equipment_identifier_gas = conn.equipment_identifier_gas()
|
||||||
|
|
||||||
|
# Check only for equipment identifier in case no gas meter is connected
|
||||||
|
if equipment_identifier is None:
|
||||||
|
raise CannotCommunicate
|
||||||
|
|
||||||
|
info = {
|
||||||
|
CONF_SERIAL_ID: equipment_identifier,
|
||||||
|
CONF_SERIAL_ID_GAS: equipment_identifier_gas,
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
"""Handle a config flow for DSMR."""
|
"""Handle a config flow for DSMR."""
|
||||||
|
|
||||||
@ -55,9 +154,29 @@ class DSMRFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
if status is not None:
|
if status is not None:
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = await _validate_dsmr_connection(self.hass, import_config)
|
||||||
|
except CannotConnect:
|
||||||
|
return self.async_abort(reason="cannot_connect")
|
||||||
|
except CannotCommunicate:
|
||||||
|
return self.async_abort(reason="cannot_communicate")
|
||||||
|
|
||||||
if host is not None:
|
if host is not None:
|
||||||
name = f"{host}:{port}"
|
name = f"{host}:{port}"
|
||||||
else:
|
else:
|
||||||
name = port
|
name = port
|
||||||
|
|
||||||
return self.async_create_entry(title=name, data=import_config)
|
data = {**import_config, **info}
|
||||||
|
|
||||||
|
await self.async_set_unique_id(info[CONF_SERIAL_ID])
|
||||||
|
self._abort_if_unique_id_configured(data)
|
||||||
|
|
||||||
|
return self.async_create_entry(title=name, data=data)
|
||||||
|
|
||||||
|
|
||||||
|
class CannotConnect(exceptions.HomeAssistantError):
|
||||||
|
"""Error to indicate we cannot connect."""
|
||||||
|
|
||||||
|
|
||||||
|
class CannotCommunicate(exceptions.HomeAssistantError):
|
||||||
|
"""Error to indicate we cannot connect."""
|
||||||
|
@ -8,6 +8,9 @@ CONF_DSMR_VERSION = "dsmr_version"
|
|||||||
CONF_RECONNECT_INTERVAL = "reconnect_interval"
|
CONF_RECONNECT_INTERVAL = "reconnect_interval"
|
||||||
CONF_PRECISION = "precision"
|
CONF_PRECISION = "precision"
|
||||||
|
|
||||||
|
CONF_SERIAL_ID = "serial_id"
|
||||||
|
CONF_SERIAL_ID_GAS = "serial_id_gas"
|
||||||
|
|
||||||
DEFAULT_DSMR_VERSION = "2.2"
|
DEFAULT_DSMR_VERSION = "2.2"
|
||||||
DEFAULT_PORT = "/dev/ttyUSB0"
|
DEFAULT_PORT = "/dev/ttyUSB0"
|
||||||
DEFAULT_PRECISION = 3
|
DEFAULT_PRECISION = 3
|
||||||
|
@ -1,12 +1,65 @@
|
|||||||
"""Test the DSMR config flow."""
|
"""Test the DSMR config flow."""
|
||||||
|
import asyncio
|
||||||
|
from itertools import chain, repeat
|
||||||
|
|
||||||
|
from dsmr_parser.clients.protocol import DSMRProtocol
|
||||||
|
from dsmr_parser.obis_references import EQUIPMENT_IDENTIFIER, EQUIPMENT_IDENTIFIER_GAS
|
||||||
|
from dsmr_parser.objects import CosemObject
|
||||||
|
import pytest
|
||||||
|
import serial
|
||||||
|
|
||||||
from homeassistant import config_entries, setup
|
from homeassistant import config_entries, setup
|
||||||
from homeassistant.components.dsmr import DOMAIN
|
from homeassistant.components.dsmr import DOMAIN
|
||||||
|
|
||||||
from tests.async_mock import patch
|
from tests.async_mock import DEFAULT, AsyncMock, Mock, patch
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
SERIAL_DATA = {"serial_id": "12345678", "serial_id_gas": "123456789"}
|
||||||
|
|
||||||
async def test_import_usb(hass):
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_connection_factory(monkeypatch):
|
||||||
|
"""Mock the create functions for serial and TCP Asyncio connections."""
|
||||||
|
transport = Mock(spec=asyncio.Transport)
|
||||||
|
protocol = Mock(spec=DSMRProtocol)
|
||||||
|
|
||||||
|
async def connection_factory(*args, **kwargs):
|
||||||
|
"""Return mocked out Asyncio classes."""
|
||||||
|
return (transport, protocol)
|
||||||
|
|
||||||
|
connection_factory = Mock(wraps=connection_factory)
|
||||||
|
|
||||||
|
# apply the mock to both connection factories
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"homeassistant.components.dsmr.config_flow.create_dsmr_reader",
|
||||||
|
connection_factory,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"homeassistant.components.dsmr.config_flow.create_tcp_dsmr_reader",
|
||||||
|
connection_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
protocol.telegram = {
|
||||||
|
EQUIPMENT_IDENTIFIER: CosemObject([{"value": "12345678", "unit": ""}]),
|
||||||
|
EQUIPMENT_IDENTIFIER_GAS: CosemObject([{"value": "123456789", "unit": ""}]),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def wait_closed():
|
||||||
|
if isinstance(connection_factory.call_args_list[0][0][2], str):
|
||||||
|
# TCP
|
||||||
|
telegram_callback = connection_factory.call_args_list[0][0][3]
|
||||||
|
else:
|
||||||
|
# Serial
|
||||||
|
telegram_callback = connection_factory.call_args_list[0][0][2]
|
||||||
|
|
||||||
|
telegram_callback(protocol.telegram)
|
||||||
|
|
||||||
|
protocol.wait_closed = wait_closed
|
||||||
|
|
||||||
|
return connection_factory, transport, protocol
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_usb(hass, mock_connection_factory):
|
||||||
"""Test we can import."""
|
"""Test we can import."""
|
||||||
await setup.async_setup_component(hass, "persistent_notification", {})
|
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||||
|
|
||||||
@ -26,10 +79,103 @@ async def test_import_usb(hass):
|
|||||||
|
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert result["title"] == "/dev/ttyUSB0"
|
assert result["title"] == "/dev/ttyUSB0"
|
||||||
assert result["data"] == entry_data
|
assert result["data"] == {**entry_data, **SERIAL_DATA}
|
||||||
|
|
||||||
|
|
||||||
async def test_import_network(hass):
|
async def test_import_usb_failed_connection(hass, monkeypatch, mock_connection_factory):
|
||||||
|
"""Test we can import."""
|
||||||
|
(connection_factory, transport, protocol) = mock_connection_factory
|
||||||
|
|
||||||
|
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||||
|
|
||||||
|
entry_data = {
|
||||||
|
"port": "/dev/ttyUSB0",
|
||||||
|
"dsmr_version": "2.2",
|
||||||
|
"precision": 4,
|
||||||
|
"reconnect_interval": 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
# override the mock to have it fail the first time and succeed after
|
||||||
|
first_fail_connection_factory = AsyncMock(
|
||||||
|
return_value=(transport, protocol),
|
||||||
|
side_effect=chain([serial.serialutil.SerialException], repeat(DEFAULT)),
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"homeassistant.components.dsmr.config_flow.create_dsmr_reader",
|
||||||
|
first_fail_connection_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={"source": config_entries.SOURCE_IMPORT},
|
||||||
|
data=entry_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "abort"
|
||||||
|
assert result["reason"] == "cannot_connect"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_usb_no_data(hass, monkeypatch, mock_connection_factory):
|
||||||
|
"""Test we can import."""
|
||||||
|
(connection_factory, transport, protocol) = mock_connection_factory
|
||||||
|
|
||||||
|
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||||
|
|
||||||
|
entry_data = {
|
||||||
|
"port": "/dev/ttyUSB0",
|
||||||
|
"dsmr_version": "2.2",
|
||||||
|
"precision": 4,
|
||||||
|
"reconnect_interval": 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
# override the mock to have it fail the first time and succeed after
|
||||||
|
wait_closed = AsyncMock(
|
||||||
|
return_value=(transport, protocol),
|
||||||
|
side_effect=chain([asyncio.TimeoutError], repeat(DEFAULT)),
|
||||||
|
)
|
||||||
|
|
||||||
|
protocol.wait_closed = wait_closed
|
||||||
|
|
||||||
|
with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={"source": config_entries.SOURCE_IMPORT},
|
||||||
|
data=entry_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "abort"
|
||||||
|
assert result["reason"] == "cannot_communicate"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_usb_wrong_telegram(hass, mock_connection_factory):
|
||||||
|
"""Test we can import."""
|
||||||
|
(connection_factory, transport, protocol) = mock_connection_factory
|
||||||
|
|
||||||
|
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||||
|
|
||||||
|
entry_data = {
|
||||||
|
"port": "/dev/ttyUSB0",
|
||||||
|
"dsmr_version": "2.2",
|
||||||
|
"precision": 4,
|
||||||
|
"reconnect_interval": 30,
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol.telegram = {}
|
||||||
|
|
||||||
|
with patch("homeassistant.components.dsmr.async_setup_entry", return_value=True):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={"source": config_entries.SOURCE_IMPORT},
|
||||||
|
data=entry_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == "abort"
|
||||||
|
assert result["reason"] == "cannot_communicate"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_network(hass, mock_connection_factory):
|
||||||
"""Test we can import from network."""
|
"""Test we can import from network."""
|
||||||
await setup.async_setup_component(hass, "persistent_notification", {})
|
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||||
|
|
||||||
@ -50,10 +196,10 @@ async def test_import_network(hass):
|
|||||||
|
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert result["title"] == "localhost:1234"
|
assert result["title"] == "localhost:1234"
|
||||||
assert result["data"] == entry_data
|
assert result["data"] == {**entry_data, **SERIAL_DATA}
|
||||||
|
|
||||||
|
|
||||||
async def test_import_update(hass):
|
async def test_import_update(hass, mock_connection_factory):
|
||||||
"""Test we can import."""
|
"""Test we can import."""
|
||||||
await setup.async_setup_component(hass, "persistent_notification", {})
|
await setup.async_setup_component(hass, "persistent_notification", {})
|
||||||
|
|
||||||
|
@ -61,8 +61,13 @@ async def test_setup_platform(hass, mock_connection_factory):
|
|||||||
"reconnect_interval": 30,
|
"reconnect_interval": 30,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serial_data = {"serial_id": "1234", "serial_id_gas": "5678"}
|
||||||
|
|
||||||
with patch("homeassistant.components.dsmr.async_setup", return_value=True), patch(
|
with patch("homeassistant.components.dsmr.async_setup", return_value=True), patch(
|
||||||
"homeassistant.components.dsmr.async_setup_entry", return_value=True
|
"homeassistant.components.dsmr.async_setup_entry", return_value=True
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.dsmr.config_flow._validate_dsmr_connection",
|
||||||
|
return_value=serial_data,
|
||||||
):
|
):
|
||||||
assert await async_setup_component(
|
assert await async_setup_component(
|
||||||
hass, SENSOR_DOMAIN, {SENSOR_DOMAIN: entry_data}
|
hass, SENSOR_DOMAIN, {SENSOR_DOMAIN: entry_data}
|
||||||
@ -79,7 +84,7 @@ async def test_setup_platform(hass, mock_connection_factory):
|
|||||||
entry = conf_entries[0]
|
entry = conf_entries[0]
|
||||||
|
|
||||||
assert entry.state == "loaded"
|
assert entry.state == "loaded"
|
||||||
assert entry.data == entry_data
|
assert entry.data == {**entry_data, **serial_data}
|
||||||
|
|
||||||
|
|
||||||
async def test_default_setup(hass, mock_connection_factory):
|
async def test_default_setup(hass, mock_connection_factory):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user