Simplify DATA_TYPE -> struct conversion. (#53805)

This commit is contained in:
jan iversen 2021-07-31 23:17:23 +02:00 committed by GitHub
parent f1f293de02
commit 3d52bfc8f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 71 additions and 68 deletions

View File

@ -1,5 +1,4 @@
"""Constants used in modbus integration.""" """Constants used in modbus integration."""
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.climate.const import DOMAIN as CLIMATE_DOMAIN from homeassistant.components.climate.const import DOMAIN as CLIMATE_DOMAIN
from homeassistant.components.cover import DOMAIN as COVER_DOMAIN from homeassistant.components.cover import DOMAIN as COVER_DOMAIN
@ -110,18 +109,8 @@ DEFAULT_HUB = "modbus_hub"
DEFAULT_SCAN_INTERVAL = 15 # seconds DEFAULT_SCAN_INTERVAL = 15 # seconds
DEFAULT_SLAVE = 1 DEFAULT_SLAVE = 1
DEFAULT_STRUCTURE_PREFIX = ">f" DEFAULT_STRUCTURE_PREFIX = ">f"
DEFAULT_STRUCT_FORMAT = {
DATA_TYPE_INT16: ["h", 1],
DATA_TYPE_INT32: ["i", 2],
DATA_TYPE_INT64: ["q", 4],
DATA_TYPE_UINT16: ["H", 1],
DATA_TYPE_UINT32: ["I", 2],
DATA_TYPE_UINT64: ["Q", 4],
DATA_TYPE_FLOAT16: ["e", 1],
DATA_TYPE_FLOAT32: ["f", 2],
DATA_TYPE_FLOAT64: ["d", 4],
DATA_TYPE_STRING: ["s", 1],
}
DEFAULT_TEMP_UNIT = "C" DEFAULT_TEMP_UNIT = "C"
MODBUS_DOMAIN = "modbus" MODBUS_DOMAIN = "modbus"

View File

@ -1,6 +1,6 @@
"""Support for Modbus.""" """Support for Modbus."""
import asyncio import asyncio
from copy import deepcopy from collections import namedtuple
import logging import logging
from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient
@ -54,54 +54,52 @@ from .const import (
SERVICE_WRITE_REGISTER, SERVICE_WRITE_REGISTER,
) )
ENTRY_FUNC = "func"
ENTRY_ATTR = "attr"
ENTRY_NAME = "name"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PYMODBUS_CALL = { ConfEntry = namedtuple("ConfEntry", "call_type attr func_name")
CALL_TYPE_COIL: { RunEntry = namedtuple("RunEntry", "attr func")
ENTRY_ATTR: "bits", PYMODBUS_CALL = [
ENTRY_NAME: "read_coils", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_COIL,
}, "bits",
CALL_TYPE_DISCRETE: { "read_coils",
ENTRY_ATTR: "bits", ),
ENTRY_NAME: "read_discrete_inputs", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_DISCRETE,
}, "bits",
CALL_TYPE_REGISTER_HOLDING: { "read_discrete_inputs",
ENTRY_ATTR: "registers", ),
ENTRY_NAME: "read_holding_registers", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_REGISTER_HOLDING,
}, "registers",
CALL_TYPE_REGISTER_INPUT: { "read_holding_registers",
ENTRY_ATTR: "registers", ),
ENTRY_NAME: "read_input_registers", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_REGISTER_INPUT,
}, "registers",
CALL_TYPE_WRITE_COIL: { "read_input_registers",
ENTRY_ATTR: "value", ),
ENTRY_NAME: "write_coil", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_WRITE_COIL,
}, "value",
CALL_TYPE_WRITE_COILS: { "write_coil",
ENTRY_ATTR: "count", ),
ENTRY_NAME: "write_coils", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_WRITE_COILS,
}, "count",
CALL_TYPE_WRITE_REGISTER: { "write_coils",
ENTRY_ATTR: "value", ),
ENTRY_NAME: "write_register", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_WRITE_REGISTER,
}, "value",
CALL_TYPE_WRITE_REGISTERS: { "write_register",
ENTRY_ATTR: "count", ),
ENTRY_NAME: "write_registers", ConfEntry(
ENTRY_FUNC: None, CALL_TYPE_WRITE_REGISTERS,
}, "count",
} "write_registers",
),
]
async def async_modbus_setup( async def async_modbus_setup(
@ -197,7 +195,7 @@ class ModbusHub:
self._config_name = client_config[CONF_NAME] self._config_name = client_config[CONF_NAME]
self._config_type = client_config[CONF_TYPE] self._config_type = client_config[CONF_TYPE]
self._config_delay = client_config[CONF_DELAY] self._config_delay = client_config[CONF_DELAY]
self._pb_call = deepcopy(PYMODBUS_CALL) self._pb_call = {}
self._pb_class = { self._pb_class = {
CONF_SERIAL: ModbusSerialClient, CONF_SERIAL: ModbusSerialClient,
CONF_TCP: ModbusTcpClient, CONF_TCP: ModbusTcpClient,
@ -246,8 +244,9 @@ class ModbusHub:
self._log_error(str(exception_error), error_state=False) self._log_error(str(exception_error), error_state=False)
return False return False
for entry in self._pb_call.values(): for entry in PYMODBUS_CALL:
entry[ENTRY_FUNC] = getattr(self._client, entry[ENTRY_NAME]) func = getattr(self._client, entry.func_name)
self._pb_call[entry.call_type] = RunEntry(entry.attr, func)
await self.async_connect_task() await self.async_connect_task()
return True return True
@ -301,12 +300,13 @@ class ModbusHub:
def _pymodbus_call(self, unit, address, value, use_call): def _pymodbus_call(self, unit, address, value, use_call):
"""Call sync. pymodbus.""" """Call sync. pymodbus."""
kwargs = {"unit": unit} if unit else {} kwargs = {"unit": unit} if unit else {}
entry = self._pb_call[use_call]
try: try:
result = self._pb_call[use_call][ENTRY_FUNC](address, value, **kwargs) result = entry.func(address, value, **kwargs)
except ModbusException as exception_error: except ModbusException as exception_error:
self._log_error(str(exception_error)) self._log_error(str(exception_error))
return None return None
if not hasattr(result, self._pb_call[use_call][ENTRY_ATTR]): if not hasattr(result, entry.attr):
self._log_error(str(result)) self._log_error(str(result))
return None return None
self._in_error = False self._in_error = False

View File

@ -1,6 +1,7 @@
"""Validate Modbus configuration.""" """Validate Modbus configuration."""
from __future__ import annotations from __future__ import annotations
from collections import namedtuple
import logging import logging
import struct import struct
from typing import Any from typing import Any
@ -29,12 +30,12 @@ from .const import (
DATA_TYPE_INT16, DATA_TYPE_INT16,
DATA_TYPE_INT32, DATA_TYPE_INT32,
DATA_TYPE_INT64, DATA_TYPE_INT64,
DATA_TYPE_STRING,
DATA_TYPE_UINT, DATA_TYPE_UINT,
DATA_TYPE_UINT16, DATA_TYPE_UINT16,
DATA_TYPE_UINT32, DATA_TYPE_UINT32,
DATA_TYPE_UINT64, DATA_TYPE_UINT64,
DEFAULT_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL,
DEFAULT_STRUCT_FORMAT,
PLATFORMS, PLATFORMS,
) )
@ -57,6 +58,19 @@ OLD_DATA_TYPES = {
4: DATA_TYPE_FLOAT64, 4: DATA_TYPE_FLOAT64,
}, },
} }
ENTRY = namedtuple("ENTRY", ["struct_id", "register_count"])
DEFAULT_STRUCT_FORMAT = {
DATA_TYPE_INT16: ENTRY("h", 1),
DATA_TYPE_INT32: ENTRY("i", 2),
DATA_TYPE_INT64: ENTRY("q", 4),
DATA_TYPE_UINT16: ENTRY("H", 1),
DATA_TYPE_UINT32: ENTRY("I", 2),
DATA_TYPE_UINT64: ENTRY("Q", 4),
DATA_TYPE_FLOAT16: ENTRY("e", 1),
DATA_TYPE_FLOAT32: ENTRY("f", 2),
DATA_TYPE_FLOAT64: ENTRY("d", 4),
DATA_TYPE_STRING: ENTRY("s", 1),
}
def struct_validator(config): def struct_validator(config):
@ -79,9 +93,9 @@ def struct_validator(config):
if structure: if structure:
error = f"{name} structure: cannot be mixed with {data_type}" error = f"{name} structure: cannot be mixed with {data_type}"
raise vol.Invalid(error) raise vol.Invalid(error)
structure = f">{DEFAULT_STRUCT_FORMAT[data_type][0]}" structure = f">{DEFAULT_STRUCT_FORMAT[data_type].struct_id}"
if CONF_COUNT not in config: if CONF_COUNT not in config:
config[CONF_COUNT] = DEFAULT_STRUCT_FORMAT[data_type][1] config[CONF_COUNT] = DEFAULT_STRUCT_FORMAT[data_type].register_count
else: else:
if not structure: if not structure:
error = ( error = (