Fix refactoring error in snmp switch (#119028)

This commit is contained in:
J. Nick Koston 2024-06-07 13:09:48 -05:00 committed by GitHub
parent cd7f2f9f77
commit 440185be25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 40 deletions

View File

@ -8,6 +8,8 @@ from typing import Any
import pysnmp.hlapi.asyncio as hlapi import pysnmp.hlapi.asyncio as hlapi
from pysnmp.hlapi.asyncio import ( from pysnmp.hlapi.asyncio import (
CommunityData, CommunityData,
ObjectIdentity,
ObjectType,
UdpTransportTarget, UdpTransportTarget,
UsmUserData, UsmUserData,
getCmd, getCmd,
@ -63,7 +65,12 @@ from .const import (
MAP_PRIV_PROTOCOLS, MAP_PRIV_PROTOCOLS,
SNMP_VERSIONS, SNMP_VERSIONS,
) )
from .util import RequestArgsType, async_create_request_cmd_args from .util import (
CommandArgsType,
RequestArgsType,
async_create_command_cmd_args,
async_create_request_cmd_args,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -125,23 +132,23 @@ async def async_setup_platform(
discovery_info: DiscoveryInfoType | None = None, discovery_info: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the SNMP switch.""" """Set up the SNMP switch."""
name = config.get(CONF_NAME) name: str = config[CONF_NAME]
host = config.get(CONF_HOST) host: str = config[CONF_HOST]
port = config.get(CONF_PORT) port: int = config[CONF_PORT]
community = config.get(CONF_COMMUNITY) community = config.get(CONF_COMMUNITY)
baseoid: str = config[CONF_BASEOID] baseoid: str = config[CONF_BASEOID]
command_oid = config.get(CONF_COMMAND_OID) command_oid: str | None = config.get(CONF_COMMAND_OID)
command_payload_on = config.get(CONF_COMMAND_PAYLOAD_ON) command_payload_on: str | None = config.get(CONF_COMMAND_PAYLOAD_ON)
command_payload_off = config.get(CONF_COMMAND_PAYLOAD_OFF) command_payload_off: str | None = config.get(CONF_COMMAND_PAYLOAD_OFF)
version: str = config[CONF_VERSION] version: str = config[CONF_VERSION]
username = config.get(CONF_USERNAME) username = config.get(CONF_USERNAME)
authkey = config.get(CONF_AUTH_KEY) authkey = config.get(CONF_AUTH_KEY)
authproto: str = config[CONF_AUTH_PROTOCOL] authproto: str = config[CONF_AUTH_PROTOCOL]
privkey = config.get(CONF_PRIV_KEY) privkey = config.get(CONF_PRIV_KEY)
privproto: str = config[CONF_PRIV_PROTOCOL] privproto: str = config[CONF_PRIV_PROTOCOL]
payload_on = config.get(CONF_PAYLOAD_ON) payload_on: str = config[CONF_PAYLOAD_ON]
payload_off = config.get(CONF_PAYLOAD_OFF) payload_off: str = config[CONF_PAYLOAD_OFF]
vartype = config.get(CONF_VARTYPE) vartype: str = config[CONF_VARTYPE]
if version == "3": if version == "3":
if not authkey: if not authkey:
@ -159,9 +166,11 @@ async def async_setup_platform(
else: else:
auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version]) auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version])
transport = UdpTransportTarget((host, port))
request_args = await async_create_request_cmd_args( request_args = await async_create_request_cmd_args(
hass, auth_data, UdpTransportTarget((host, port)), baseoid hass, auth_data, transport, baseoid
) )
command_args = await async_create_command_cmd_args(hass, auth_data, transport)
async_add_entities( async_add_entities(
[ [
@ -177,6 +186,7 @@ async def async_setup_platform(
command_payload_off, command_payload_off,
vartype, vartype,
request_args, request_args,
command_args,
) )
], ],
True, True,
@ -188,21 +198,22 @@ class SnmpSwitch(SwitchEntity):
def __init__( def __init__(
self, self,
name, name: str,
host, host: str,
port, port: int,
baseoid, baseoid: str,
commandoid, commandoid: str | None,
payload_on, payload_on: str,
payload_off, payload_off: str,
command_payload_on, command_payload_on: str | None,
command_payload_off, command_payload_off: str | None,
vartype, vartype: str,
request_args, request_args: RequestArgsType,
command_args: CommandArgsType,
) -> None: ) -> None:
"""Initialize the switch.""" """Initialize the switch."""
self._name = name self._attr_name = name
self._baseoid = baseoid self._baseoid = baseoid
self._vartype = vartype self._vartype = vartype
@ -215,7 +226,8 @@ class SnmpSwitch(SwitchEntity):
self._payload_on = payload_on self._payload_on = payload_on
self._payload_off = payload_off self._payload_off = payload_off
self._target = UdpTransportTarget((host, port)) self._target = UdpTransportTarget((host, port))
self._request_args: RequestArgsType = request_args self._request_args = request_args
self._command_args = command_args
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on the switch.""" """Turn on the switch."""
@ -226,7 +238,7 @@ class SnmpSwitch(SwitchEntity):
"""Turn off the switch.""" """Turn off the switch."""
await self._execute_command(self._command_payload_off) await self._execute_command(self._command_payload_off)
async def _execute_command(self, command): async def _execute_command(self, command: str) -> None:
# User did not set vartype and command is not a digit # User did not set vartype and command is not a digit
if self._vartype == "none" and not self._command_payload_on.isdigit(): if self._vartype == "none" and not self._command_payload_on.isdigit():
await self._set(command) await self._set(command)
@ -265,14 +277,12 @@ class SnmpSwitch(SwitchEntity):
self._state = None self._state = None
@property @property
def name(self): def is_on(self) -> bool | None:
"""Return the switch's name."""
return self._name
@property
def is_on(self):
"""Return true if switch is on; False if off. None if unknown.""" """Return true if switch is on; False if off. None if unknown."""
return self._state return self._state
async def _set(self, value): async def _set(self, value: Any) -> None:
await setCmd(*self._request_args, value) """Set the state of the switch."""
await setCmd(
*self._command_args, ObjectType(ObjectIdentity(self._commandoid), value)
)

View File

@ -25,6 +25,14 @@ DATA_SNMP_ENGINE = "snmp_engine"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
type CommandArgsType = tuple[
SnmpEngine,
UsmUserData | CommunityData,
UdpTransportTarget | Udp6TransportTarget,
ContextData,
]
type RequestArgsType = tuple[ type RequestArgsType = tuple[
SnmpEngine, SnmpEngine,
UsmUserData | CommunityData, UsmUserData | CommunityData,
@ -34,20 +42,34 @@ type RequestArgsType = tuple[
] ]
async def async_create_command_cmd_args(
hass: HomeAssistant,
auth_data: UsmUserData | CommunityData,
target: UdpTransportTarget | Udp6TransportTarget,
) -> CommandArgsType:
"""Create command arguments.
The ObjectType needs to be created dynamically by the caller.
"""
engine = await async_get_snmp_engine(hass)
return (engine, auth_data, target, ContextData())
async def async_create_request_cmd_args( async def async_create_request_cmd_args(
hass: HomeAssistant, hass: HomeAssistant,
auth_data: UsmUserData | CommunityData, auth_data: UsmUserData | CommunityData,
target: UdpTransportTarget | Udp6TransportTarget, target: UdpTransportTarget | Udp6TransportTarget,
object_id: str, object_id: str,
) -> RequestArgsType: ) -> RequestArgsType:
"""Create request arguments.""" """Create request arguments.
return (
await async_get_snmp_engine(hass), The same ObjectType is used for all requests.
auth_data, """
target, engine, auth_data, target, context_data = await async_create_command_cmd_args(
ContextData(), hass, auth_data, target
ObjectType(ObjectIdentity(object_id)),
) )
object_type = ObjectType(ObjectIdentity(object_id))
return (engine, auth_data, target, context_data, object_type)
@singleton(DATA_SNMP_ENGINE) @singleton(DATA_SNMP_ENGINE)