From 1cbd3ab9307fed9e75f898bbe2c4f66a8c8990f5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 7 Jun 2024 13:09:48 -0500 Subject: [PATCH] Fix refactoring error in snmp switch (#119028) --- homeassistant/components/snmp/switch.py | 76 ++++++++++++++----------- homeassistant/components/snmp/util.py | 36 +++++++++--- 2 files changed, 72 insertions(+), 40 deletions(-) diff --git a/homeassistant/components/snmp/switch.py b/homeassistant/components/snmp/switch.py index 40083ed4213..02a94aeb8c1 100644 --- a/homeassistant/components/snmp/switch.py +++ b/homeassistant/components/snmp/switch.py @@ -8,6 +8,8 @@ from typing import Any import pysnmp.hlapi.asyncio as hlapi from pysnmp.hlapi.asyncio import ( CommunityData, + ObjectIdentity, + ObjectType, UdpTransportTarget, UsmUserData, getCmd, @@ -63,7 +65,12 @@ from .const import ( MAP_PRIV_PROTOCOLS, SNMP_VERSIONS, ) -from .util import RequestArgsType, async_create_request_cmd_args +from .util import ( + CommandArgsType, + RequestArgsType, + async_create_command_cmd_args, + async_create_request_cmd_args, +) _LOGGER = logging.getLogger(__name__) @@ -125,23 +132,23 @@ async def async_setup_platform( discovery_info: DiscoveryInfoType | None = None, ) -> None: """Set up the SNMP switch.""" - name = config.get(CONF_NAME) - host = config.get(CONF_HOST) - port = config.get(CONF_PORT) + name: str = config[CONF_NAME] + host: str = config[CONF_HOST] + port: int = config[CONF_PORT] community = config.get(CONF_COMMUNITY) baseoid: str = config[CONF_BASEOID] - command_oid = config.get(CONF_COMMAND_OID) - command_payload_on = config.get(CONF_COMMAND_PAYLOAD_ON) - command_payload_off = config.get(CONF_COMMAND_PAYLOAD_OFF) + command_oid: str | None = config.get(CONF_COMMAND_OID) + command_payload_on: str | None = config.get(CONF_COMMAND_PAYLOAD_ON) + command_payload_off: str | None = config.get(CONF_COMMAND_PAYLOAD_OFF) version: str = config[CONF_VERSION] username = config.get(CONF_USERNAME) authkey = config.get(CONF_AUTH_KEY) authproto: str = config[CONF_AUTH_PROTOCOL] privkey = config.get(CONF_PRIV_KEY) privproto: str = config[CONF_PRIV_PROTOCOL] - payload_on = config.get(CONF_PAYLOAD_ON) - payload_off = config.get(CONF_PAYLOAD_OFF) - vartype = config.get(CONF_VARTYPE) + payload_on: str = config[CONF_PAYLOAD_ON] + payload_off: str = config[CONF_PAYLOAD_OFF] + vartype: str = config[CONF_VARTYPE] if version == "3": if not authkey: @@ -159,9 +166,11 @@ async def async_setup_platform( else: auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version]) + transport = UdpTransportTarget((host, port)) request_args = await async_create_request_cmd_args( - hass, auth_data, UdpTransportTarget((host, port)), baseoid + hass, auth_data, transport, baseoid ) + command_args = await async_create_command_cmd_args(hass, auth_data, transport) async_add_entities( [ @@ -177,6 +186,7 @@ async def async_setup_platform( command_payload_off, vartype, request_args, + command_args, ) ], True, @@ -188,21 +198,22 @@ class SnmpSwitch(SwitchEntity): def __init__( self, - name, - host, - port, - baseoid, - commandoid, - payload_on, - payload_off, - command_payload_on, - command_payload_off, - vartype, - request_args, + name: str, + host: str, + port: int, + baseoid: str, + commandoid: str | None, + payload_on: str, + payload_off: str, + command_payload_on: str | None, + command_payload_off: str | None, + vartype: str, + request_args: RequestArgsType, + command_args: CommandArgsType, ) -> None: """Initialize the switch.""" - self._name = name + self._attr_name = name self._baseoid = baseoid self._vartype = vartype @@ -215,7 +226,8 @@ class SnmpSwitch(SwitchEntity): self._payload_on = payload_on self._payload_off = payload_off self._target = UdpTransportTarget((host, port)) - self._request_args: RequestArgsType = request_args + self._request_args = request_args + self._command_args = command_args async def async_turn_on(self, **kwargs: Any) -> None: """Turn on the switch.""" @@ -226,7 +238,7 @@ class SnmpSwitch(SwitchEntity): """Turn off the switch.""" await self._execute_command(self._command_payload_off) - async def _execute_command(self, command): + async def _execute_command(self, command: str) -> None: # User did not set vartype and command is not a digit if self._vartype == "none" and not self._command_payload_on.isdigit(): await self._set(command) @@ -265,14 +277,12 @@ class SnmpSwitch(SwitchEntity): self._state = None @property - def name(self): - """Return the switch's name.""" - return self._name - - @property - def is_on(self): + def is_on(self) -> bool | None: """Return true if switch is on; False if off. None if unknown.""" return self._state - async def _set(self, value): - await setCmd(*self._request_args, value) + async def _set(self, value: Any) -> None: + """Set the state of the switch.""" + await setCmd( + *self._command_args, ObjectType(ObjectIdentity(self._commandoid), value) + ) diff --git a/homeassistant/components/snmp/util.py b/homeassistant/components/snmp/util.py index 23adbdf0b90..dd3e9a6b6d2 100644 --- a/homeassistant/components/snmp/util.py +++ b/homeassistant/components/snmp/util.py @@ -25,6 +25,14 @@ DATA_SNMP_ENGINE = "snmp_engine" _LOGGER = logging.getLogger(__name__) +type CommandArgsType = tuple[ + SnmpEngine, + UsmUserData | CommunityData, + UdpTransportTarget | Udp6TransportTarget, + ContextData, +] + + type RequestArgsType = tuple[ SnmpEngine, UsmUserData | CommunityData, @@ -34,20 +42,34 @@ type RequestArgsType = tuple[ ] +async def async_create_command_cmd_args( + hass: HomeAssistant, + auth_data: UsmUserData | CommunityData, + target: UdpTransportTarget | Udp6TransportTarget, +) -> CommandArgsType: + """Create command arguments. + + The ObjectType needs to be created dynamically by the caller. + """ + engine = await async_get_snmp_engine(hass) + return (engine, auth_data, target, ContextData()) + + async def async_create_request_cmd_args( hass: HomeAssistant, auth_data: UsmUserData | CommunityData, target: UdpTransportTarget | Udp6TransportTarget, object_id: str, ) -> RequestArgsType: - """Create request arguments.""" - return ( - await async_get_snmp_engine(hass), - auth_data, - target, - ContextData(), - ObjectType(ObjectIdentity(object_id)), + """Create request arguments. + + The same ObjectType is used for all requests. + """ + engine, auth_data, target, context_data = await async_create_command_cmd_args( + hass, auth_data, target ) + object_type = ObjectType(ObjectIdentity(object_id)) + return (engine, auth_data, target, context_data, object_type) @singleton(DATA_SNMP_ENGINE)