Fix modbus blocking threads (#50619)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
jan iversen 2021-05-15 19:54:17 +02:00 committed by GitHub
parent 990b7c371f
commit ad7be91b6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 284 additions and 327 deletions

View File

@ -101,7 +101,7 @@ from .const import (
MODBUS_DOMAIN as DOMAIN, MODBUS_DOMAIN as DOMAIN,
PLATFORMS, PLATFORMS,
) )
from .modbus import modbus_setup from .modbus import async_modbus_setup
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -350,8 +350,8 @@ SERVICE_WRITE_COIL_SCHEMA = vol.Schema(
) )
def setup(hass, config): async def async_setup(hass, config):
"""Set up Modbus component.""" """Set up Modbus component."""
return modbus_setup( return await async_modbus_setup(
hass, config, SERVICE_WRITE_REGISTER_SCHEMA, SERVICE_WRITE_COIL_SCHEMA hass, config, SERVICE_WRITE_REGISTER_SCHEMA, SERVICE_WRITE_COIL_SCHEMA
) )

View File

@ -36,6 +36,7 @@ from .const import (
MODBUS_DOMAIN, MODBUS_DOMAIN,
) )
PARALLEL_UPDATES = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -114,9 +115,7 @@ class ModbusBinarySensor(BinarySensorEntity):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Handle entity which will be added.""" """Handle entity which will be added."""
async_track_time_interval( async_track_time_interval(self._hass, self.async_update, self._scan_interval)
self._hass, lambda arg: self.update(), self._scan_interval
)
@property @property
def name(self): def name(self):
@ -148,17 +147,21 @@ class ModbusBinarySensor(BinarySensorEntity):
"""Return True if entity is available.""" """Return True if entity is available."""
return self._available return self._available
def update(self): async def async_update(self, now=None):
"""Update the state of the sensor.""" """Update the state of the sensor."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
if self._input_type == CALL_TYPE_COIL: if self._input_type == CALL_TYPE_COIL:
result = self._hub.read_coils(self._slave, self._address, 1) result = await self._hub.async_read_coils(self._slave, self._address, 1)
else: else:
result = self._hub.read_discrete_inputs(self._slave, self._address, 1) result = await self._hub.async_read_discrete_inputs(
self._slave, self._address, 1
)
if result is None: if result is None:
self._available = False self._available = False
self.schedule_update_ha_state() self.async_write_ha_state()
return return
self._value = result.bits[0] & 1 self._value = result.bits[0] & 1
self._available = True self._available = True
self.schedule_update_ha_state() self.async_write_ha_state()

View File

@ -46,6 +46,7 @@ from .const import (
) )
from .modbus import ModbusHub from .modbus import ModbusHub
PARALLEL_UPDATES = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -132,9 +133,7 @@ class ModbusThermostat(ClimateEntity):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Handle entity which will be added.""" """Handle entity which will be added."""
async_track_time_interval( async_track_time_interval(self.hass, self.async_update, self._scan_interval)
self.hass, lambda arg: self.update(), self._scan_interval
)
@property @property
def should_poll(self): def should_poll(self):
@ -160,7 +159,7 @@ class ModbusThermostat(ClimateEntity):
"""Return the possible HVAC modes.""" """Return the possible HVAC modes."""
return [HVAC_MODE_AUTO] return [HVAC_MODE_AUTO]
def set_hvac_mode(self, hvac_mode: str) -> None: async def async_set_hvac_mode(self, hvac_mode: str) -> None:
"""Set new target hvac mode.""" """Set new target hvac mode."""
# Home Assistant expects this method. # Home Assistant expects this method.
# We'll keep it here to avoid getting exceptions. # We'll keep it here to avoid getting exceptions.
@ -200,7 +199,7 @@ class ModbusThermostat(ClimateEntity):
"""Return the supported step of target temperature.""" """Return the supported step of target temperature."""
return self._temp_step return self._temp_step
def set_temperature(self, **kwargs): async def async_set_temperature(self, **kwargs):
"""Set new target temperature.""" """Set new target temperature."""
if ATTR_TEMPERATURE not in kwargs: if ATTR_TEMPERATURE not in kwargs:
return return
@ -209,35 +208,39 @@ class ModbusThermostat(ClimateEntity):
) )
byte_string = struct.pack(self._structure, target_temperature) byte_string = struct.pack(self._structure, target_temperature)
register_value = struct.unpack(">h", byte_string[0:2])[0] register_value = struct.unpack(">h", byte_string[0:2])[0]
self._available = self._hub.write_registers( self._available = await self._hub.async_write_registers(
self._slave, self._slave,
self._target_temperature_register, self._target_temperature_register,
register_value, register_value,
) )
self.update() self.async_update()
@property @property
def available(self) -> bool: def available(self) -> bool:
"""Return True if entity is available.""" """Return True if entity is available."""
return self._available return self._available
def update(self): async def async_update(self, now=None):
"""Update Target & Current Temperature.""" """Update Target & Current Temperature."""
self._target_temperature = self._read_register( # remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
self._target_temperature = await self._async_read_register(
CALL_TYPE_REGISTER_HOLDING, self._target_temperature_register CALL_TYPE_REGISTER_HOLDING, self._target_temperature_register
) )
self._current_temperature = self._read_register( self._current_temperature = await self._async_read_register(
self._current_temperature_register_type, self._current_temperature_register self._current_temperature_register_type, self._current_temperature_register
) )
self.schedule_update_ha_state() self.async_write_ha_state()
def _read_register(self, register_type, register) -> float | None: async def _async_read_register(self, register_type, register) -> float | None:
"""Read register using the Modbus hub slave.""" """Read register using the Modbus hub slave."""
if register_type == CALL_TYPE_REGISTER_INPUT: if register_type == CALL_TYPE_REGISTER_INPUT:
result = self._hub.read_input_registers(self._slave, register, self._count) result = await self._hub.async_read_input_registers(
self._slave, register, self._count
)
else: else:
result = self._hub.read_holding_registers( result = await self._hub.async_read_holding_registers(
self._slave, register, self._count self._slave, register, self._count
) )
if result is None: if result is None:

View File

@ -33,6 +33,7 @@ from .const import (
) )
from .modbus import ModbusHub from .modbus import ModbusHub
PARALLEL_UPDATES = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -106,9 +107,7 @@ class ModbusCover(CoverEntity, RestoreEntity):
if state: if state:
self._value = state.state self._value = state.state
async_track_time_interval( async_track_time_interval(self.hass, self.async_update, self._scan_interval)
self.hass, lambda arg: self.update(), self._scan_interval
)
@property @property
def device_class(self) -> str | None: def device_class(self) -> str | None:
@ -154,41 +153,43 @@ class ModbusCover(CoverEntity, RestoreEntity):
# Handle polling directly in this entity # Handle polling directly in this entity
return False return False
def open_cover(self, **kwargs: Any) -> None: async def async_open_cover(self, **kwargs: Any) -> None:
"""Open cover.""" """Open cover."""
if self._coil is not None: if self._coil is not None:
self._write_coil(True) await self._async_write_coil(True)
else: else:
self._write_register(self._state_open) await self._async_write_register(self._state_open)
self.update() self.async_update()
def close_cover(self, **kwargs: Any) -> None: async def async_close_cover(self, **kwargs: Any) -> None:
"""Close cover.""" """Close cover."""
if self._coil is not None: if self._coil is not None:
self._write_coil(False) await self._async_write_coil(False)
else: else:
self._write_register(self._state_closed) await self._async_write_register(self._state_closed)
self.update() self.async_update()
def update(self): async def async_update(self, now=None):
"""Update the state of the cover.""" """Update the state of the cover."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
if self._coil is not None and self._status_register is None: if self._coil is not None and self._status_register is None:
self._value = self._read_coil() self._value = await self._async_read_coil()
else: else:
self._value = self._read_status_register() self._value = await self._async_read_status_register()
self.schedule_update_ha_state() self.async_write_ha_state()
def _read_status_register(self) -> int | None: async def _async_read_status_register(self) -> int | None:
"""Read status register using the Modbus hub slave.""" """Read status register using the Modbus hub slave."""
if self._status_register_type == CALL_TYPE_REGISTER_INPUT: if self._status_register_type == CALL_TYPE_REGISTER_INPUT:
result = self._hub.read_input_registers( result = await self._hub.async_read_input_registers(
self._slave, self._status_register, 1 self._slave, self._status_register, 1
) )
else: else:
result = self._hub.read_holding_registers( result = await self._hub.async_read_holding_registers(
self._slave, self._status_register, 1 self._slave, self._status_register, 1
) )
if result is None: if result is None:
@ -200,13 +201,15 @@ class ModbusCover(CoverEntity, RestoreEntity):
return value return value
def _write_register(self, value): async def _async_write_register(self, value):
"""Write holding register using the Modbus hub slave.""" """Write holding register using the Modbus hub slave."""
self._available = self._hub.write_register(self._slave, self._register, value) self._available = await self._hub.async_write_register(
self._slave, self._register, value
)
def _read_coil(self) -> bool | None: async def _async_read_coil(self) -> bool | None:
"""Read coil using the Modbus hub slave.""" """Read coil using the Modbus hub slave."""
result = self._hub.read_coils(self._slave, self._coil, 1) result = await self._hub.async_read_coils(self._slave, self._coil, 1)
if result is None: if result is None:
self._available = False self._available = False
return None return None
@ -214,6 +217,8 @@ class ModbusCover(CoverEntity, RestoreEntity):
value = bool(result.bits[0] & 1) value = bool(result.bits[0] & 1)
return value return value
def _write_coil(self, value): async def _async_write_coil(self, value):
"""Write coil using the Modbus hub slave.""" """Write coil using the Modbus hub slave."""
self._available = self._hub.write_coil(self._slave, self._coil, value) self._available = await self._hub.async_write_coil(
self._slave, self._coil, value
)

View File

@ -1,6 +1,6 @@
"""Support for Modbus.""" """Support for Modbus."""
import asyncio
import logging import logging
import threading
from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient from pymodbus.client.sync import ModbusSerialClient, ModbusTcpClient, ModbusUdpClient
from pymodbus.constants import Defaults from pymodbus.constants import Defaults
@ -17,8 +17,9 @@ from homeassistant.const import (
CONF_TYPE, CONF_TYPE,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.helpers.discovery import load_platform from homeassistant.core import callback
from homeassistant.helpers.event import call_later from homeassistant.helpers.discovery import async_load_platform
from homeassistant.helpers.event import async_call_later
from .const import ( from .const import (
ATTR_ADDRESS, ATTR_ADDRESS,
@ -41,32 +42,37 @@ from .const import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def modbus_setup( async def async_modbus_setup(
hass, config, service_write_register_schema, service_write_coil_schema hass, config, service_write_register_schema, service_write_coil_schema
): ):
"""Set up Modbus component.""" """Set up Modbus component."""
hass.data[DOMAIN] = hub_collect = {} hass.data[DOMAIN] = hub_collect = {}
for conf_hub in config[DOMAIN]: for conf_hub in config[DOMAIN]:
hub_collect[conf_hub[CONF_NAME]] = ModbusHub(conf_hub) my_hub = ModbusHub(hass, conf_hub)
hub_collect[conf_hub[CONF_NAME]] = my_hub
# modbus needs to be activated before components are loaded # modbus needs to be activated before components are loaded
# to avoid a racing problem # to avoid a racing problem
hub_collect[conf_hub[CONF_NAME]].setup(hass) await my_hub.async_setup()
# load platforms # load platforms
for component, conf_key in PLATFORMS: for component, conf_key in PLATFORMS:
if conf_key in conf_hub: if conf_key in conf_hub:
load_platform(hass, component, DOMAIN, conf_hub, config) hass.async_create_task(
async_load_platform(hass, component, DOMAIN, conf_hub, config)
)
def stop_modbus(event): async def async_stop_modbus(event):
"""Stop Modbus service.""" """Stop Modbus service."""
for client in hub_collect.values(): for client in hub_collect.values():
client.close() await client.async_close()
del client del client
def write_register(service): hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_modbus)
async def async_write_register(service):
"""Write Modbus registers.""" """Write Modbus registers."""
unit = int(float(service.data[ATTR_UNIT])) unit = int(float(service.data[ATTR_UNIT]))
address = int(float(service.data[ATTR_ADDRESS])) address = int(float(service.data[ATTR_ADDRESS]))
@ -75,13 +81,22 @@ def modbus_setup(
service.data[ATTR_HUB] if ATTR_HUB in service.data else DEFAULT_HUB service.data[ATTR_HUB] if ATTR_HUB in service.data else DEFAULT_HUB
) )
if isinstance(value, list): if isinstance(value, list):
hub_collect[client_name].write_registers( await hub_collect[client_name].async_write_registers(
unit, address, [int(float(i)) for i in value] unit, address, [int(float(i)) for i in value]
) )
else: else:
hub_collect[client_name].write_register(unit, address, int(float(value))) await hub_collect[client_name].async_write_register(
unit, address, int(float(value))
)
def write_coil(service): hass.services.async_register(
DOMAIN,
SERVICE_WRITE_REGISTER,
async_write_register,
schema=service_write_register_schema,
)
async def async_write_coil(service):
"""Write Modbus coil.""" """Write Modbus coil."""
unit = service.data[ATTR_UNIT] unit = service.data[ATTR_UNIT]
address = service.data[ATTR_ADDRESS] address = service.data[ATTR_ADDRESS]
@ -90,22 +105,12 @@ def modbus_setup(
service.data[ATTR_HUB] if ATTR_HUB in service.data else DEFAULT_HUB service.data[ATTR_HUB] if ATTR_HUB in service.data else DEFAULT_HUB
) )
if isinstance(state, list): if isinstance(state, list):
hub_collect[client_name].write_coils(unit, address, state) await hub_collect[client_name].async_write_coils(unit, address, state)
else: else:
hub_collect[client_name].write_coil(unit, address, state) await hub_collect[client_name].async_write_coil(unit, address, state)
# register function to gracefully stop modbus hass.services.async_register(
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_modbus) DOMAIN, SERVICE_WRITE_COIL, async_write_coil, schema=service_write_coil_schema
# Register services for modbus
hass.services.register(
DOMAIN,
SERVICE_WRITE_REGISTER,
write_register,
schema=service_write_register_schema,
)
hass.services.register(
DOMAIN, SERVICE_WRITE_COIL, write_coil, schema=service_write_coil_schema
) )
return True return True
@ -113,14 +118,15 @@ def modbus_setup(
class ModbusHub: class ModbusHub:
"""Thread safe wrapper class for pymodbus.""" """Thread safe wrapper class for pymodbus."""
def __init__(self, client_config): def __init__(self, hass, client_config):
"""Initialize the Modbus hub.""" """Initialize the Modbus hub."""
# generic configuration # generic configuration
self._client = None self._client = None
self._cancel_listener = None self._async_cancel_listener = None
self._in_error = False self._in_error = False
self._lock = threading.Lock() self._lock = asyncio.Lock()
self.hass = hass
self._config_name = client_config[CONF_NAME] self._config_name = client_config[CONF_NAME]
self._config_type = client_config[CONF_TYPE] self._config_type = client_config[CONF_TYPE]
self._config_port = client_config[CONF_PORT] self._config_port = client_config[CONF_PORT]
@ -152,7 +158,7 @@ class ModbusHub:
_LOGGER.error(log_text) _LOGGER.error(log_text)
self._in_error = error_state self._in_error = error_state
def setup(self, hass): async def async_setup(self):
"""Set up pymodbus client.""" """Set up pymodbus client."""
try: try:
if self._config_type == "serial": if self._config_type == "serial":
@ -193,166 +199,113 @@ class ModbusHub:
self._log_error(exception_error, error_state=False) self._log_error(exception_error, error_state=False)
return return
# Connect device async with self._lock:
self.connect() await self.hass.async_add_executor_job(self._pymodbus_connect)
# Start counting down to allow modbus requests. # Start counting down to allow modbus requests.
if self._config_delay: if self._config_delay:
self._cancel_listener = call_later(hass, self._config_delay, self.end_delay) self._async_cancel_listener = async_call_later(
self.hass, self._config_delay, self.async_end_delay
)
def end_delay(self, args): @callback
def async_end_delay(self, args):
"""End startup delay.""" """End startup delay."""
self._cancel_listener = None self._async_cancel_listener = None
self._config_delay = 0 self._config_delay = 0
def close(self): def _pymodbus_close(self):
"""Close sync. pymodbus."""
if self._client:
try:
self._client.close()
except ModbusException as exception_error:
self._log_error(exception_error)
self._client = None
async def async_close(self):
"""Disconnect client.""" """Disconnect client."""
if self._cancel_listener: if self._async_cancel_listener:
self._cancel_listener() self._async_cancel_listener()
self._cancel_listener = None self._async_cancel_listener = None
with self._lock:
try:
if self._client:
self._client.close()
self._client = None
except ModbusException as exception_error:
self._log_error(exception_error)
return
def connect(self): async with self._lock:
return await self.hass.async_add_executor_job(self._pymodbus_close)
def _pymodbus_connect(self):
"""Connect client.""" """Connect client."""
with self._lock: try:
try: self._client.connect()
self._client.connect() except ModbusException as exception_error:
except ModbusException as exception_error: self._log_error(exception_error, error_state=False)
self._log_error(exception_error, error_state=False)
return
def read_coils(self, unit, address, count): def _pymodbus_call(self, unit, address, value, check_attr, func):
"""Call sync. pymodbus."""
kwargs = {"unit": unit} if unit else {}
try:
result = func(address, value, **kwargs)
except ModbusException as exception_error:
self._log_error(exception_error)
result = exception_error
if not hasattr(result, check_attr):
self._log_error(result)
return None
self._in_error = False
return result
async def async_pymodbus_call(self, unit, address, value, check_attr, func):
"""Convert async to sync pymodbus call."""
if self._config_delay:
return None
async with self._lock:
return await self.hass.async_add_executor_job(
self._pymodbus_call, unit, address, value, check_attr, func
)
async def async_read_coils(self, unit, address, count):
"""Read coils.""" """Read coils."""
if self._config_delay: return await self.async_pymodbus_call(
return None unit, address, count, "bits", self._client.read_coils
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.read_coils(address, count, **kwargs)
except ModbusException as exception_error:
self._log_error(exception_error)
result = exception_error
if not hasattr(result, "bits"):
self._log_error(result)
return None
self._in_error = False
return result
def read_discrete_inputs(self, unit, address, count): async def async_read_discrete_inputs(self, unit, address, count):
"""Read discrete inputs.""" """Read discrete inputs."""
if self._config_delay: return await self.async_pymodbus_call(
return None unit, address, count, "bits", self._client.read_discrete_inputs
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.read_discrete_inputs(address, count, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "bits"):
self._log_error(result)
return None
self._in_error = False
return result
def read_input_registers(self, unit, address, count): async def async_read_input_registers(self, unit, address, count):
"""Read input registers.""" """Read input registers."""
if self._config_delay: return await self.async_pymodbus_call(
return None unit, address, count, "registers", self._client.read_input_registers
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.read_input_registers(address, count, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "registers"):
self._log_error(result)
return None
self._in_error = False
return result
def read_holding_registers(self, unit, address, count): async def async_read_holding_registers(self, unit, address, count):
"""Read holding registers.""" """Read holding registers."""
if self._config_delay: return await self.async_pymodbus_call(
return None unit, address, count, "registers", self._client.read_holding_registers
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.read_holding_registers(address, count, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "registers"):
self._log_error(result)
return None
self._in_error = False
return result
def write_coil(self, unit, address, value) -> bool: async def async_write_coil(self, unit, address, value) -> bool:
"""Write coil.""" """Write coil."""
if self._config_delay: return await self.async_pymodbus_call(
return False unit, address, value, "value", self._client.write_coil
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.write_coil(address, value, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "value"):
self._log_error(result)
return False
self._in_error = False
return True
def write_coils(self, unit, address, values) -> bool: async def async_write_coils(self, unit, address, values) -> bool:
"""Write coil.""" """Write coil."""
if self._config_delay: return await self.async_pymodbus_call(
return False unit, address, values, "count", self._client.write_coils
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.write_coils(address, values, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "count"):
self._log_error(result)
return False
self._in_error = False
return True
def write_register(self, unit, address, value) -> bool: async def async_write_register(self, unit, address, value) -> bool:
"""Write register.""" """Write register."""
if self._config_delay: return await self.async_pymodbus_call(
return False unit, address, value, "value", self._client.write_register
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.write_register(address, value, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "value"):
self._log_error(result)
return False
self._in_error = False
return True
def write_registers(self, unit, address, values) -> bool: async def async_write_registers(self, unit, address, values) -> bool:
"""Write registers.""" """Write registers."""
if self._config_delay: return await self.async_pymodbus_call(
return False unit, address, values, "count", self._client.write_registers
with self._lock: )
kwargs = {"unit": unit} if unit else {}
try:
result = self._client.write_registers(address, values, **kwargs)
except ModbusException as exception_error:
result = exception_error
if not hasattr(result, "count"):
self._log_error(result)
return False
self._in_error = False
return True

View File

@ -59,6 +59,7 @@ from .const import (
MODBUS_DOMAIN, MODBUS_DOMAIN,
) )
PARALLEL_UPDATES = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -226,9 +227,7 @@ class ModbusRegisterSensor(RestoreEntity, SensorEntity):
if state: if state:
self._value = state.state self._value = state.state
async_track_time_interval( async_track_time_interval(self.hass, self.async_update, self._scan_interval)
self.hass, lambda arg: self.update(), self._scan_interval
)
@property @property
def state(self): def state(self):
@ -280,19 +279,21 @@ class ModbusRegisterSensor(RestoreEntity, SensorEntity):
registers.reverse() registers.reverse()
return registers return registers
def update(self): async def async_update(self, now=None):
"""Update the state of the sensor.""" """Update the state of the sensor."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
if self._register_type == CALL_TYPE_REGISTER_INPUT: if self._register_type == CALL_TYPE_REGISTER_INPUT:
result = self._hub.read_input_registers( result = await self._hub.async_read_input_registers(
self._slave, self._register, self._count self._slave, self._register, self._count
) )
else: else:
result = self._hub.read_holding_registers( result = await self._hub.async_read_holding_registers(
self._slave, self._register, self._count self._slave, self._register, self._count
) )
if result is None: if result is None:
self._available = False self._available = False
self.schedule_update_ha_state() self.async_write_ha_state()
return return
registers = self._swap_registers(result.registers) registers = self._swap_registers(result.registers)
@ -332,4 +333,4 @@ class ModbusRegisterSensor(RestoreEntity, SensorEntity):
self._value = f"{float(val):.{self._precision}f}" self._value = f"{float(val):.{self._precision}f}"
self._available = True self._available = True
self.schedule_update_ha_state() self.async_write_ha_state()

View File

@ -34,6 +34,7 @@ from .const import (
) )
from .modbus import ModbusHub from .modbus import ModbusHub
PARALLEL_UPDATES = 1
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -62,11 +63,11 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
self._scan_interval = timedelta(seconds=config[CONF_SCAN_INTERVAL]) self._scan_interval = timedelta(seconds=config[CONF_SCAN_INTERVAL])
self._address = config[CONF_ADDRESS] self._address = config[CONF_ADDRESS]
if config[CONF_WRITE_TYPE] == CALL_TYPE_COIL: if config[CONF_WRITE_TYPE] == CALL_TYPE_COIL:
self._write_func = self._hub.write_coil self._async_write_func = self._hub.async_write_coil
self._command_on = 0x01 self._command_on = 0x01
self._command_off = 0x00 self._command_off = 0x00
else: else:
self._write_func = self._hub.write_register self._async_write_func = self._hub.async_write_register
self._command_on = config[CONF_COMMAND_ON] self._command_on = config[CONF_COMMAND_ON]
self._command_off = config[CONF_COMMAND_OFF] self._command_off = config[CONF_COMMAND_OFF]
if CONF_VERIFY in config: if CONF_VERIFY in config:
@ -83,13 +84,13 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
self._state_off = config[CONF_VERIFY].get(CONF_STATE_OFF, self._command_off) self._state_off = config[CONF_VERIFY].get(CONF_STATE_OFF, self._command_off)
if self._verify_type == CALL_TYPE_REGISTER_HOLDING: if self._verify_type == CALL_TYPE_REGISTER_HOLDING:
self._read_func = self._hub.read_holding_registers self._async_read_func = self._hub.async_read_holding_registers
elif self._verify_type == CALL_TYPE_DISCRETE: elif self._verify_type == CALL_TYPE_DISCRETE:
self._read_func = self._hub.read_discrete_inputs self._async_read_func = self._hub.async_read_discrete_inputs
elif self._verify_type == CALL_TYPE_REGISTER_INPUT: elif self._verify_type == CALL_TYPE_REGISTER_INPUT:
self._read_func = self._hub.read_input_registers self._async_read_func = self._hub.async_read_input_registers
else: # self._verify_type == CALL_TYPE_COIL: else: # self._verify_type == CALL_TYPE_COIL:
self._read_func = self._hub.read_coils self._async_read_func = self._hub.async_read_coils
else: else:
self._verify_active = False self._verify_active = False
@ -99,9 +100,7 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
if state: if state:
self._is_on = state.state == STATE_ON self._is_on = state.state == STATE_ON
async_track_time_interval( async_track_time_interval(self.hass, self.async_update, self._scan_interval)
self.hass, lambda arg: self.update(), self._scan_interval
)
@property @property
def is_on(self): def is_on(self):
@ -123,46 +122,52 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
"""Return True if entity is available.""" """Return True if entity is available."""
return self._available return self._available
def turn_on(self, **kwargs): async def async_turn_on(self, **kwargs):
"""Set switch on.""" """Set switch on."""
result = self._write_func(self._slave, self._address, self._command_on) result = await self._async_write_func(
self._slave, self._address, self._command_on
)
if result is False: if result is False:
self._available = False self._available = False
self.schedule_update_ha_state() self.async_write_ha_state()
else: else:
self._available = True self._available = True
if self._verify_active: if self._verify_active:
self.update() self.async_update()
else: else:
self._is_on = True self._is_on = True
self.schedule_update_ha_state() self.async_write_ha_state()
def turn_off(self, **kwargs): async def async_turn_off(self, **kwargs):
"""Set switch off.""" """Set switch off."""
result = self._write_func(self._slave, self._address, self._command_off) result = await self._async_write_func(
self._slave, self._address, self._command_off
)
if result is False: if result is False:
self._available = False self._available = False
self.schedule_update_ha_state() self.async_write_ha_state()
else: else:
self._available = True self._available = True
if self._verify_active: if self._verify_active:
self.update() self.async_update()
else: else:
self._is_on = False self._is_on = False
self.schedule_update_ha_state() self.async_write_ha_state()
def update(self): async def async_update(self, now=None):
"""Update the entity state.""" """Update the entity state."""
# remark "now" is a dummy parameter to avoid problems with
# async_track_time_interval
if not self._verify_active: if not self._verify_active:
self._available = True self._available = True
self.schedule_update_ha_state() self.async_write_ha_state()
return return
result = self._read_func(self._slave, self._verify_address, 1) result = await self._async_read_func(self._slave, self._verify_address, 1)
if result is None: if result is None:
self._available = False self._available = False
self.schedule_update_ha_state() self.async_write_ha_state()
return return
self._available = True self._available = True
@ -182,4 +187,4 @@ class ModbusSwitch(SwitchEntity, RestoreEntity):
self._verify_address, self._verify_address,
value, value,
) )
self.schedule_update_ha_state() self.async_write_ha_state()

View File

@ -480,11 +480,13 @@ async def test_pymodbus_connect_fail(hass, caplog, mock_pymodbus):
async def test_delay(hass, mock_pymodbus): async def test_delay(hass, mock_pymodbus):
"""Run test for different read.""" """Run test for startup delay."""
# the purpose of this test is to test startup delay # the purpose of this test is to test startup delay
# We "hijiack" binary_sensor and sensor in order # We "hijiack" a binary_sensor to make a proper blackbox test.
# to make a proper blackbox test. test_delay = 15
test_scan_interval = 5
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}"
config = { config = {
DOMAIN: [ DOMAIN: [
{ {
@ -492,101 +494,86 @@ async def test_delay(hass, mock_pymodbus):
CONF_HOST: "modbusTestHost", CONF_HOST: "modbusTestHost",
CONF_PORT: 5501, CONF_PORT: 5501,
CONF_NAME: TEST_MODBUS_NAME, CONF_NAME: TEST_MODBUS_NAME,
CONF_DELAY: 15, CONF_DELAY: test_delay,
CONF_BINARY_SENSORS: [ CONF_BINARY_SENSORS: [
{ {
CONF_INPUT_TYPE: CALL_TYPE_COIL, CONF_INPUT_TYPE: CALL_TYPE_COIL,
CONF_NAME: f"{TEST_SENSOR_NAME}_2", CONF_NAME: f"{TEST_SENSOR_NAME}",
CONF_ADDRESS: 52, CONF_ADDRESS: 52,
CONF_SCAN_INTERVAL: 5, CONF_SCAN_INTERVAL: test_scan_interval,
},
{
CONF_INPUT_TYPE: CALL_TYPE_DISCRETE,
CONF_NAME: f"{TEST_SENSOR_NAME}_1",
CONF_ADDRESS: 51,
CONF_SCAN_INTERVAL: 5,
},
],
CONF_SENSORS: [
{
CONF_INPUT_TYPE: CALL_TYPE_REGISTER_HOLDING,
CONF_NAME: f"{TEST_SENSOR_NAME}_3",
CONF_ADDRESS: 53,
CONF_SCAN_INTERVAL: 5,
},
{
CONF_INPUT_TYPE: CALL_TYPE_REGISTER_INPUT,
CONF_NAME: f"{TEST_SENSOR_NAME}_4",
CONF_ADDRESS: 54,
CONF_SCAN_INTERVAL: 5,
}, },
], ],
} }
] ]
} }
mock_pymodbus.read_coils.return_value = ReadResult([0x01]) mock_pymodbus.read_coils.return_value = ReadResult([0x01])
mock_pymodbus.read_discrete_inputs.return_value = ReadResult([0x01])
mock_pymodbus.read_holding_registers.return_value = ReadResult([7])
mock_pymodbus.read_input_registers.return_value = ReadResult([7])
now = dt_util.utcnow() now = dt_util.utcnow()
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now): with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
assert await async_setup_component(hass, DOMAIN, config) is True assert await async_setup_component(hass, DOMAIN, config) is True
await hass.async_block_till_done() await hass.async_block_till_done()
now = now + timedelta(seconds=10) # pass first scan_interval
start_time = now
now = now + timedelta(seconds=(test_scan_interval + 1))
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now): with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
async_fire_time_changed(hass, now) async_fire_time_changed(hass, now)
await hass.async_block_till_done() await hass.async_block_till_done()
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
# Check states stop_time = start_time + timedelta(seconds=(test_delay + 1))
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_1" step_timedelta = timedelta(seconds=1)
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE while now < stop_time:
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_2" now = now + step_timedelta
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_3" async_fire_time_changed(hass, now)
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE await hass.async_block_till_done()
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_4" assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE now = now + step_timedelta + timedelta(seconds=2)
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
async_fire_time_changed(hass, now)
await hass.async_block_till_done()
assert hass.states.get(entity_id).state == STATE_ON
mock_pymodbus.reset_mock()
data = { async def test_thread_lock(hass, mock_pymodbus):
ATTR_HUB: TEST_MODBUS_NAME, """Run test for block of threads."""
ATTR_UNIT: 17,
ATTR_ADDRESS: 16, # the purpose of this test is to test the threads are not being blocked
ATTR_STATE: False, # We "hijiack" a binary_sensor to make a proper blackbox test.
test_scan_interval = 5
sensors = []
for i in range(200):
sensors.append(
{
CONF_INPUT_TYPE: CALL_TYPE_COIL,
CONF_NAME: f"{TEST_SENSOR_NAME}_{i}",
CONF_ADDRESS: 52 + i,
CONF_SCAN_INTERVAL: test_scan_interval,
}
)
config = {
DOMAIN: [
{
CONF_TYPE: "tcp",
CONF_HOST: "modbusTestHost",
CONF_PORT: 5501,
CONF_NAME: TEST_MODBUS_NAME,
CONF_BINARY_SENSORS: sensors,
}
]
} }
await hass.services.async_call(DOMAIN, SERVICE_WRITE_COIL, data, blocking=True) mock_pymodbus.read_coils.return_value = ReadResult([0x01])
assert not mock_pymodbus.write_coil.called now = dt_util.utcnow()
await hass.services.async_call(DOMAIN, SERVICE_WRITE_COIL, data, blocking=True)
assert not mock_pymodbus.write_coil.called
data[ATTR_STATE] = [True, False, True]
await hass.services.async_call(DOMAIN, SERVICE_WRITE_COIL, data, blocking=True)
assert not mock_pymodbus.write_coils.called
del data[ATTR_STATE]
data[ATTR_VALUE] = 15
await hass.services.async_call(DOMAIN, SERVICE_WRITE_REGISTER, data, blocking=True)
assert not mock_pymodbus.write_register.called
data[ATTR_VALUE] = [1, 2, 3]
await hass.services.async_call(DOMAIN, SERVICE_WRITE_REGISTER, data, blocking=True)
assert not mock_pymodbus.write_registers.called
# 2 times fire_changed is needed to secure "normal" update is called.
now = now + timedelta(seconds=6)
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now): with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
async_fire_time_changed(hass, now) assert await async_setup_component(hass, DOMAIN, config) is True
await hass.async_block_till_done() await hass.async_block_till_done()
now = now + timedelta(seconds=10) stop_time = now + timedelta(seconds=10)
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now): step_timedelta = timedelta(seconds=1)
async_fire_time_changed(hass, now) while now < stop_time:
await hass.async_block_till_done() now = now + step_timedelta
with mock.patch("homeassistant.helpers.event.dt_util.utcnow", return_value=now):
# Check states async_fire_time_changed(hass, now)
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_1" await hass.async_block_till_done()
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE for i in range(200):
entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_2" entity_id = f"{BINARY_SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_{i}"
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE assert hass.states.get(entity_id).state == STATE_ON
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_3"
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE
entity_id = f"{SENSOR_DOMAIN}.{TEST_SENSOR_NAME}_4"
assert not hass.states.get(entity_id).state == STATE_UNAVAILABLE