Fix passing value to pymodbus low level function (#135108)

This commit is contained in:
Thijs W. 2025-01-22 12:33:21 +01:00 committed by GitHub
parent 1ea6cba1f5
commit 99d1c51a3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 33 additions and 13 deletions

View File

@ -72,48 +72,56 @@ from .validators import check_config
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ConfEntry = namedtuple("ConfEntry", "call_type attr func_name") # noqa: PYI024 ConfEntry = namedtuple("ConfEntry", "call_type attr func_name value_attr_name") # noqa: PYI024
RunEntry = namedtuple("RunEntry", "attr func") # noqa: PYI024 RunEntry = namedtuple("RunEntry", "attr func value_attr_name") # noqa: PYI024
PB_CALL = [ PB_CALL = [
ConfEntry( ConfEntry(
CALL_TYPE_COIL, CALL_TYPE_COIL,
"bits", "bits",
"read_coils", "read_coils",
"count",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_DISCRETE, CALL_TYPE_DISCRETE,
"bits", "bits",
"read_discrete_inputs", "read_discrete_inputs",
"count",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_REGISTER_HOLDING, CALL_TYPE_REGISTER_HOLDING,
"registers", "registers",
"read_holding_registers", "read_holding_registers",
"count",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_REGISTER_INPUT, CALL_TYPE_REGISTER_INPUT,
"registers", "registers",
"read_input_registers", "read_input_registers",
"count",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_WRITE_COIL, CALL_TYPE_WRITE_COIL,
"value", "bits",
"write_coil", "write_coil",
"value",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_WRITE_COILS, CALL_TYPE_WRITE_COILS,
"count", "count",
"write_coils", "write_coils",
"values",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_WRITE_REGISTER, CALL_TYPE_WRITE_REGISTER,
"value", "registers",
"write_register", "write_register",
"value",
), ),
ConfEntry( ConfEntry(
CALL_TYPE_WRITE_REGISTERS, CALL_TYPE_WRITE_REGISTERS,
"count", "count",
"write_registers", "write_registers",
"values",
), ),
] ]
@ -322,7 +330,9 @@ class ModbusHub:
for entry in PB_CALL: for entry in PB_CALL:
func = getattr(self._client, entry.func_name) func = getattr(self._client, entry.func_name)
self._pb_request[entry.call_type] = RunEntry(entry.attr, func) self._pb_request[entry.call_type] = RunEntry(
entry.attr, func, entry.value_attr_name
)
self.hass.async_create_background_task( self.hass.async_create_background_task(
self.async_pb_connect(), "modbus-connect" self.async_pb_connect(), "modbus-connect"
@ -368,10 +378,11 @@ class ModbusHub:
self, slave: int | None, address: int, value: int | list[int], use_call: str self, slave: int | None, address: int, value: int | list[int], use_call: str
) -> ModbusPDU | None: ) -> ModbusPDU | None:
"""Call sync. pymodbus.""" """Call sync. pymodbus."""
kwargs = {"slave": slave} if slave else {} kwargs: dict[str, Any] = {"slave": slave} if slave else {}
entry = self._pb_request[use_call] entry = self._pb_request[use_call]
kwargs[entry.value_attr_name] = value
try: try:
result: ModbusPDU = await entry.func(address, value, **kwargs) result: ModbusPDU = await entry.func(address, **kwargs)
except ModbusException as exception_error: except ModbusException as exception_error:
error = f"Error: device: {slave} address: {address} -> {exception_error!s}" error = f"Error: device: {slave} address: {address} -> {exception_error!s}"
self._log_error(error) self._log_error(error)

View File

@ -394,7 +394,7 @@ async def test_hvac_onoff_values(hass: HomeAssistant, mock_modbus) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
mock_modbus.write_register.assert_called_with(11, 0xAA, slave=10) mock_modbus.write_register.assert_called_with(11, value=0xAA, slave=10)
await hass.services.async_call( await hass.services.async_call(
CLIMATE_DOMAIN, CLIMATE_DOMAIN,
@ -404,7 +404,7 @@ async def test_hvac_onoff_values(hass: HomeAssistant, mock_modbus) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
mock_modbus.write_register.assert_called_with(11, 0xFF, slave=10) mock_modbus.write_register.assert_called_with(11, value=0xFF, slave=10)
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -846,6 +846,13 @@ async def test_pb_service_write(
CALL_TYPE_WRITE_REGISTERS: mock_modbus_with_pymodbus.write_registers, CALL_TYPE_WRITE_REGISTERS: mock_modbus_with_pymodbus.write_registers,
} }
value_arg_name = {
CALL_TYPE_WRITE_COIL: "value",
CALL_TYPE_WRITE_COILS: "values",
CALL_TYPE_WRITE_REGISTER: "value",
CALL_TYPE_WRITE_REGISTERS: "values",
}
data = { data = {
ATTR_HUB: TEST_MODBUS_NAME, ATTR_HUB: TEST_MODBUS_NAME,
do_slave: 17, do_slave: 17,
@ -858,10 +865,12 @@ async def test_pb_service_write(
func_name[do_write[FUNC]].return_value = do_return[VALUE] func_name[do_write[FUNC]].return_value = do_return[VALUE]
await hass.services.async_call(DOMAIN, do_write[SERVICE], data, blocking=True) await hass.services.async_call(DOMAIN, do_write[SERVICE], data, blocking=True)
assert func_name[do_write[FUNC]].called assert func_name[do_write[FUNC]].called
assert func_name[do_write[FUNC]].call_args[0] == ( assert func_name[do_write[FUNC]].call_args.args == (data[ATTR_ADDRESS],)
data[ATTR_ADDRESS], assert func_name[do_write[FUNC]].call_args.kwargs == {
data[do_write[DATA]], "slave": 17,
) value_arg_name[do_write[FUNC]]: data[do_write[DATA]],
}
if do_return[DATA]: if do_return[DATA]:
assert any(message.startswith("Pymodbus:") for message in caplog.messages) assert any(message.startswith("Pymodbus:") for message in caplog.messages)