diff --git a/homeassistant/components/unifi/entity.py b/homeassistant/components/unifi/entity.py index 05ad2f56a8c..28a7b557b16 100644 --- a/homeassistant/components/unifi/entity.py +++ b/homeassistant/components/unifi/entity.py @@ -53,12 +53,12 @@ def async_wlan_available_fn(controller: UniFiController, obj_id: str) -> bool: @callback -def async_device_device_info_fn(api: aiounifi.Controller, obj_id: str) -> DeviceInfo: +def async_device_device_info_fn(controller: UniFiController, obj_id: str) -> DeviceInfo: """Create device registry entry for device.""" if "_" in obj_id: # Sub device (outlet or port) obj_id = obj_id.partition("_")[0] - device = api.devices[obj_id] + device = controller.api.devices[obj_id] return DeviceInfo( connections={(CONNECTION_NETWORK_MAC, device.mac)}, manufacturer=ATTR_MANUFACTURER, @@ -70,9 +70,9 @@ def async_device_device_info_fn(api: aiounifi.Controller, obj_id: str) -> Device @callback -def async_wlan_device_info_fn(api: aiounifi.Controller, obj_id: str) -> DeviceInfo: +def async_wlan_device_info_fn(controller: UniFiController, obj_id: str) -> DeviceInfo: """Create device registry entry for WLAN.""" - wlan = api.wlans[obj_id] + wlan = controller.api.wlans[obj_id] return DeviceInfo( entry_type=DeviceEntryType.SERVICE, identifiers={(DOMAIN, wlan.id)}, @@ -83,9 +83,9 @@ def async_wlan_device_info_fn(api: aiounifi.Controller, obj_id: str) -> DeviceIn @callback -def async_client_device_info_fn(api: aiounifi.Controller, obj_id: str) -> DeviceInfo: +def async_client_device_info_fn(controller: UniFiController, obj_id: str) -> DeviceInfo: """Create device registry entry for client.""" - client = api.clients[obj_id] + client = controller.api.clients[obj_id] return DeviceInfo( connections={(CONNECTION_NETWORK_MAC, obj_id)}, default_manufacturer=client.oui, @@ -100,7 +100,7 @@ class UnifiDescription(Generic[HandlerT, ApiItemT]): allowed_fn: Callable[[UniFiController, str], bool] api_handler_fn: Callable[[aiounifi.Controller], HandlerT] available_fn: Callable[[UniFiController, str], bool] - device_info_fn: Callable[[aiounifi.Controller, str], DeviceInfo | None] + device_info_fn: Callable[[UniFiController, str], DeviceInfo | None] event_is_on: tuple[EventKey, ...] | None event_to_subscribe: tuple[EventKey, ...] | None name_fn: Callable[[ApiItemT], str | None] @@ -137,7 +137,7 @@ class UnifiEntity(Entity, Generic[HandlerT, ApiItemT]): self._removed = False self._attr_available = description.available_fn(controller, obj_id) - self._attr_device_info = description.device_info_fn(controller.api, obj_id) + self._attr_device_info = description.device_info_fn(controller, obj_id) self._attr_should_poll = description.should_poll self._attr_unique_id = description.unique_id_fn(controller, obj_id) diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index e2b4dda3912..046aa3a1abd 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -17,6 +17,7 @@ from aiounifi.interfaces.api_handlers import ItemEvent from aiounifi.interfaces.clients import Clients from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups from aiounifi.interfaces.outlets import Outlets +from aiounifi.interfaces.port_forwarding import PortForwarding from aiounifi.interfaces.ports import Ports from aiounifi.interfaces.wlans import Wlans from aiounifi.models.api import ApiItemT @@ -30,6 +31,7 @@ from aiounifi.models.dpi_restriction_group import DPIRestrictionGroup from aiounifi.models.event import Event, EventKey from aiounifi.models.outlet import Outlet from aiounifi.models.port import Port +from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest from aiounifi.models.wlan import Wlan, WlanEnableRequest from homeassistant.components.switch import ( @@ -75,7 +77,9 @@ def async_dpi_group_is_on_fn( @callback -def async_dpi_group_device_info_fn(api: aiounifi.Controller, obj_id: str) -> DeviceInfo: +def async_dpi_group_device_info_fn( + controller: UniFiController, obj_id: str +) -> DeviceInfo: """Create device registry entry for DPI group.""" return DeviceInfo( entry_type=DeviceEntryType.SERVICE, @@ -86,6 +90,22 @@ def async_dpi_group_device_info_fn(api: aiounifi.Controller, obj_id: str) -> Dev ) +@callback +def async_port_forward_device_info_fn( + controller: UniFiController, obj_id: str +) -> DeviceInfo: + """Create device registry entry for port forward.""" + unique_id = controller.config_entry.unique_id + assert unique_id is not None + return DeviceInfo( + entry_type=DeviceEntryType.SERVICE, + identifiers={(DOMAIN, unique_id)}, + manufacturer=ATTR_MANUFACTURER, + model="UniFi Network", + name="UniFi Network", + ) + + async def async_block_client_control_fn( api: aiounifi.Controller, obj_id: str, target: bool ) -> None: @@ -136,6 +156,14 @@ async def async_poe_port_control_fn( await api.request(DeviceSetPoePortModeRequest.create(device, int(index), state)) +async def async_port_forward_control_fn( + api: aiounifi.Controller, obj_id: str, target: bool +) -> None: + """Control port forward state.""" + port_forward = api.port_forwarding[obj_id] + await api.request(PortForwardEnableRequest.create(port_forward, target)) + + async def async_wlan_control_fn( api: aiounifi.Controller, obj_id: str, target: bool ) -> None: @@ -222,6 +250,26 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSwitchEntityDescription, ...] = ( supported_fn=async_outlet_supports_switching_fn, unique_id_fn=lambda controller, obj_id: f"{obj_id.split('_', 1)[0]}-outlet-{obj_id.split('_', 1)[1]}", ), + UnifiSwitchEntityDescription[PortForwarding, PortForward]( + key="Port forward control", + device_class=SwitchDeviceClass.SWITCH, + entity_category=EntityCategory.CONFIG, + has_entity_name=True, + icon="mdi:upload-network", + allowed_fn=lambda controller, obj_id: True, + api_handler_fn=lambda api: api.port_forwarding, + available_fn=lambda controller, obj_id: controller.available, + control_fn=async_port_forward_control_fn, + device_info_fn=async_port_forward_device_info_fn, + event_is_on=None, + event_to_subscribe=None, + is_on_fn=lambda controller, port_forward: port_forward.enabled, + name_fn=lambda port_forward: f"{port_forward.name}", + object_fn=lambda api, obj_id: api.port_forwarding[obj_id], + should_poll=False, + supported_fn=lambda controller, obj_id: True, + unique_id_fn=lambda controller, obj_id: f"port_forward-{obj_id}", + ), UnifiSwitchEntityDescription[Ports, Port]( key="PoE port control", device_class=SwitchDeviceClass.OUTLET, diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index c091fc5cc59..8e3e215e717 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -1518,3 +1518,90 @@ async def test_wlan_switches( mock_unifi_websocket(state=WebsocketState.RUNNING) await hass.async_block_till_done() assert hass.states.get("switch.ssid_1").state == STATE_OFF + + +async def test_port_forwarding_switches( + hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, mock_unifi_websocket +) -> None: + """Test control of UniFi port forwarding.""" + _data = { + "_id": "5a32aa4ee4b0412345678911", + "dst_port": "12345", + "enabled": True, + "fwd_port": "23456", + "fwd": "10.0.0.2", + "name": "plex", + "pfwd_interface": "wan", + "proto": "tcp_udp", + "site_id": "5a32aa4ee4b0412345678910", + "src": "any", + } + config_entry = await setup_unifi_integration( + hass, aioclient_mock, port_forward_response=[_data.copy()] + ) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 + + ent_reg = er.async_get(hass) + ent_reg_entry = ent_reg.async_get("switch.unifi_network_plex") + assert ent_reg_entry.unique_id == "port_forward-5a32aa4ee4b0412345678911" + assert ent_reg_entry.entity_category is EntityCategory.CONFIG + + # Validate state object + switch_1 = hass.states.get("switch.unifi_network_plex") + assert switch_1 is not None + assert switch_1.state == STATE_ON + assert switch_1.attributes.get(ATTR_DEVICE_CLASS) == SwitchDeviceClass.SWITCH + + # Update state object + data = _data.copy() + data["enabled"] = False + mock_unifi_websocket(message=MessageKey.PORT_FORWARD_UPDATED, data=data) + await hass.async_block_till_done() + assert hass.states.get("switch.unifi_network_plex").state == STATE_OFF + + # Disable port forward + aioclient_mock.clear_requests() + aioclient_mock.put( + f"https://{controller.host}:1234/api/s/{controller.site}" + + f"/rest/portforward/{data['_id']}", + ) + + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_off", + {"entity_id": "switch.unifi_network_plex"}, + blocking=True, + ) + assert aioclient_mock.call_count == 1 + data = _data.copy() + data["enabled"] = False + assert aioclient_mock.mock_calls[0][2] == data + + # Enable port forward + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_on", + {"entity_id": "switch.unifi_network_plex"}, + blocking=True, + ) + assert aioclient_mock.call_count == 2 + assert aioclient_mock.mock_calls[1][2] == _data + + # Availability signalling + + # Controller disconnects + mock_unifi_websocket(state=WebsocketState.DISCONNECTED) + await hass.async_block_till_done() + assert hass.states.get("switch.unifi_network_plex").state == STATE_UNAVAILABLE + + # Controller reconnects + mock_unifi_websocket(state=WebsocketState.RUNNING) + await hass.async_block_till_done() + assert hass.states.get("switch.unifi_network_plex").state == STATE_OFF + + # Remove entity on deleted message + mock_unifi_websocket(message=MessageKey.PORT_FORWARD_DELETED, data=_data) + await hass.async_block_till_done() + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 0