From d53d8f5ea94e95836ce3f3bdd373ee8ef71293dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20Sandstr=C3=B6m?= Date: Wed, 7 Sep 2016 03:21:38 +0200 Subject: [PATCH] thread safe modbus (#3188) --- homeassistant/components/modbus.py | 84 +++++++++++++++++++---- homeassistant/components/sensor/modbus.py | 7 +- homeassistant/components/switch/modbus.py | 21 +++--- 3 files changed, 81 insertions(+), 31 deletions(-) diff --git a/homeassistant/components/modbus.py b/homeassistant/components/modbus.py index 1d6ad0e3abc..4aab9ddc756 100644 --- a/homeassistant/components/modbus.py +++ b/homeassistant/components/modbus.py @@ -5,6 +5,7 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/modbus/ """ import logging +import threading from homeassistant.const import ( EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP) @@ -37,7 +38,7 @@ ATTR_ADDRESS = "address" ATTR_UNIT = "unit" ATTR_VALUE = "value" -NETWORK = None +HUB = None TYPE = None @@ -50,34 +51,36 @@ def setup(hass, config): # Connect to Modbus network # pylint: disable=global-statement, import-error - global NETWORK if TYPE == "serial": from pymodbus.client.sync import ModbusSerialClient as ModbusClient - NETWORK = ModbusClient(method=config[DOMAIN][METHOD], - port=config[DOMAIN][SERIAL_PORT], - baudrate=config[DOMAIN][BAUDRATE], - stopbits=config[DOMAIN][STOPBITS], - bytesize=config[DOMAIN][BYTESIZE], - parity=config[DOMAIN][PARITY]) + client = ModbusClient(method=config[DOMAIN][METHOD], + port=config[DOMAIN][SERIAL_PORT], + baudrate=config[DOMAIN][BAUDRATE], + stopbits=config[DOMAIN][STOPBITS], + bytesize=config[DOMAIN][BYTESIZE], + parity=config[DOMAIN][PARITY]) elif TYPE == "tcp": from pymodbus.client.sync import ModbusTcpClient as ModbusClient - NETWORK = ModbusClient(host=config[DOMAIN][HOST], - port=config[DOMAIN][IP_PORT]) + client = ModbusClient(host=config[DOMAIN][HOST], + port=config[DOMAIN][IP_PORT]) elif TYPE == "udp": from pymodbus.client.sync import ModbusUdpClient as ModbusClient - NETWORK = ModbusClient(host=config[DOMAIN][HOST], - port=config[DOMAIN][IP_PORT]) + client = ModbusClient(host=config[DOMAIN][HOST], + port=config[DOMAIN][IP_PORT]) else: return False + global HUB + HUB = ModbusHub(client) + def stop_modbus(event): """Stop Modbus service.""" - NETWORK.close() + HUB.close() def start_modbus(event): """Start Modbus service.""" - NETWORK.connect() + HUB.connect() hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, stop_modbus) # Register services for modbus @@ -88,8 +91,59 @@ def setup(hass, config): unit = int(float(service.data.get(ATTR_UNIT))) address = int(float(service.data.get(ATTR_ADDRESS))) value = int(float(service.data.get(ATTR_VALUE))) - NETWORK.write_register(address, value, unit=unit) + HUB.write_register(unit, address, value) hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_modbus) return True + + +class ModbusHub(object): + """Thread safe wrapper class for pymodbus.""" + + def __init__(self, modbus_client): + """Initialize the modbus hub.""" + self._client = modbus_client + self._lock = threading.Lock() + + def close(self): + """Disconnect client.""" + with self._lock: + self._client.close() + + def connect(self): + """Connect client.""" + with self._lock: + self._client.connect() + + def read_coils(self, unit, address, count): + """Read coils.""" + with self._lock: + return self._client.read_coils( + address, + count, + unit=unit) + + def read_holding_registers(self, unit, address, count): + """Read holding registers.""" + with self._lock: + return self._client.read_holding_registers( + address, + count, + unit=unit) + + def write_coil(self, unit, address, value): + """Write coil.""" + with self._lock: + self._client.write_coil( + address, + value, + unit=unit) + + def write_register(self, unit, address, value): + """Write register.""" + with self._lock: + self._client.write_register( + address, + value, + unit=unit) diff --git a/homeassistant/components/sensor/modbus.py b/homeassistant/components/sensor/modbus.py index d6c85993162..063c1dc8600 100644 --- a/homeassistant/components/sensor/modbus.py +++ b/homeassistant/components/sensor/modbus.py @@ -114,12 +114,11 @@ class ModbusSensor(Entity): def update(self): """Update the state of the sensor.""" if self._coil: - result = modbus.NETWORK.read_coils(self.register, 1) + result = modbus.HUB.read_coils(self.slave, self.register, 1) self._value = result.bits[0] else: - result = modbus.NETWORK.read_holding_registers( - unit=self.slave, address=self.register, - count=1) + result = modbus.HUB.read_holding_registers( + self.slave, self.register, 1) val = 0 for i, res in enumerate(result.registers): val += res * (2**(i*16)) diff --git a/homeassistant/components/switch/modbus.py b/homeassistant/components/switch/modbus.py index 971947a6ed3..2ae0c74991d 100644 --- a/homeassistant/components/switch/modbus.py +++ b/homeassistant/components/switch/modbus.py @@ -90,12 +90,10 @@ class ModbusSwitch(ToggleEntity): self.update() if self._coil: - modbus.NETWORK.write_coil(self.register, True) + modbus.HUB.write_coil(self.slave, self.register, True) else: val = self.register_value | (0x0001 << self.bit) - modbus.NETWORK.write_register(unit=self.slave, - address=self.register, - value=val) + modbus.HUB.write_register(self.slave, self.register, val) def turn_off(self, **kwargs): """Set switch off.""" @@ -103,23 +101,22 @@ class ModbusSwitch(ToggleEntity): self.update() if self._coil: - modbus.NETWORK.write_coil(self.register, False) + modbus.HUB.write_coil(self.slave, self.register, False) else: val = self.register_value & ~(0x0001 << self.bit) - modbus.NETWORK.write_register(unit=self.slave, - address=self.register, - value=val) + modbus.HUB.write_register(self.slave, self.register, val) def update(self): """Update the state of the switch.""" if self._coil: - result = modbus.NETWORK.read_coils(self.register, 1) + result = modbus.HUB.read_coils(self.slave, self.register, 1) self.register_value = result.bits[0] self._is_on = self.register_value else: - result = modbus.NETWORK.read_holding_registers( - unit=self.slave, address=self.register, - count=1) + result = modbus.HUB.read_holding_registers( + self.slave, + self.register, + 1) val = 0 for i, res in enumerate(result.registers): val += res * (2**(i*16))