Add title feature to notify entity platform (#116426)

* Add title feature to notify entity platform

* Add overload variants

* Remove overloads, update signatures

* Improve test coverage

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Do not use const

* fix typo

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Jan Bouwhuis 2024-05-03 11:17:28 +02:00 committed by GitHub
parent ecdad19296
commit 84308c9e53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 133 additions and 23 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from homeassistant.components.notify import DOMAIN, NotifyEntity from homeassistant.components.notify import DOMAIN, NotifyEntity, NotifyEntityFeature
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
@ -33,12 +33,15 @@ class DemoNotifyEntity(NotifyEntity):
) -> None: ) -> None:
"""Initialize the Demo button entity.""" """Initialize the Demo button entity."""
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
self._attr_supported_features = NotifyEntityFeature.TITLE
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, unique_id)}, identifiers={(DOMAIN, unique_id)},
name=device_name, name=device_name,
) )
async def async_send_message(self, message: str) -> None: async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send a message to a user.""" """Send a message to a user."""
event_notitifcation = {"message": message} event_notification = {"message": message}
self.hass.bus.async_fire(EVENT_NOTIFY, event_notitifcation) if title is not None:
event_notification["title"] = title
self.hass.bus.async_fire(EVENT_NOTIFY, event_notification)

View File

@ -85,6 +85,6 @@ class EcobeeNotifyEntity(EcobeeBaseEntity, NotifyEntity):
f"{self.thermostat["identifier"]}_notify_{thermostat_index}" f"{self.thermostat["identifier"]}_notify_{thermostat_index}"
) )
def send_message(self, message: str) -> None: def send_message(self, message: str, title: str | None = None) -> None:
"""Send a message.""" """Send a message."""
self.data.ecobee.send_message(self.thermostat_index, message) self.data.ecobee.send_message(self.thermostat_index, message)

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from homeassistant.components import persistent_notification from homeassistant.components import persistent_notification
from homeassistant.components.notify import NotifyEntity from homeassistant.components.notify import NotifyEntity, NotifyEntityFeature
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
@ -25,6 +25,12 @@ async def async_setup_entry(
device_name="MyBox", device_name="MyBox",
entity_name="Personal notifier", entity_name="Personal notifier",
), ),
DemoNotify(
unique_id="just_notify_me_title",
device_name="MyBox",
entity_name="Personal notifier with title",
supported_features=NotifyEntityFeature.TITLE,
),
] ]
) )
@ -40,15 +46,19 @@ class DemoNotify(NotifyEntity):
unique_id: str, unique_id: str,
device_name: str, device_name: str,
entity_name: str | None, entity_name: str | None,
supported_features: NotifyEntityFeature = NotifyEntityFeature(0),
) -> None: ) -> None:
"""Initialize the Demo button entity.""" """Initialize the Demo button entity."""
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
self._attr_supported_features = supported_features
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, unique_id)}, identifiers={(DOMAIN, unique_id)},
name=device_name, name=device_name,
) )
self._attr_name = entity_name self._attr_name = entity_name
async def async_send_message(self, message: str) -> None: async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send out a persistent notification.""" """Send out a persistent notification."""
persistent_notification.async_create(self.hass, message, "Demo notification") persistent_notification.async_create(
self.hass, message, title or "Demo notification"
)

View File

@ -108,6 +108,6 @@ class KNXNotify(KnxEntity, NotifyEntity):
self._attr_entity_category = config.get(CONF_ENTITY_CATEGORY) self._attr_entity_category = config.get(CONF_ENTITY_CATEGORY)
self._attr_unique_id = str(self._device.remote_value.group_address) self._attr_unique_id = str(self._device.remote_value.group_address)
async def async_send_message(self, message: str) -> None: async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send a notification to knx bus.""" """Send a notification to knx bus."""
await self._device.set(message) await self._device.set(message)

View File

@ -83,7 +83,7 @@ class MqttNotify(MqttEntity, NotifyEntity):
async def _subscribe_topics(self) -> None: async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
async def async_send_message(self, message: str) -> None: async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send a message.""" """Send a message."""
payload = self._command_template(message) payload = self._command_template(message)
await self.async_publish( await self.async_publish(

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from enum import IntFlag
from functools import cached_property, partial from functools import cached_property, partial
import logging import logging
from typing import Any, final, override from typing import Any, final, override
@ -58,6 +59,12 @@ PLATFORM_SCHEMA = vol.Schema(
) )
class NotifyEntityFeature(IntFlag):
"""Supported features of a notify entity."""
TITLE = 1
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the notify services.""" """Set up the notify services."""
@ -73,7 +80,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
component = hass.data[DOMAIN] = EntityComponent[NotifyEntity](_LOGGER, DOMAIN, hass) component = hass.data[DOMAIN] = EntityComponent[NotifyEntity](_LOGGER, DOMAIN, hass)
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_SEND_MESSAGE, SERVICE_SEND_MESSAGE,
{vol.Required(ATTR_MESSAGE): cv.string}, {
vol.Required(ATTR_MESSAGE): cv.string,
vol.Optional(ATTR_TITLE): cv.string,
},
"_async_send_message", "_async_send_message",
) )
@ -128,6 +138,7 @@ class NotifyEntity(RestoreEntity):
"""Representation of a notify entity.""" """Representation of a notify entity."""
entity_description: NotifyEntityDescription entity_description: NotifyEntityDescription
_attr_supported_features: NotifyEntityFeature = NotifyEntityFeature(0)
_attr_should_poll = False _attr_should_poll = False
_attr_device_class: None _attr_device_class: None
_attr_state: None = None _attr_state: None = None
@ -162,10 +173,19 @@ class NotifyEntity(RestoreEntity):
self.async_write_ha_state() self.async_write_ha_state()
await self.async_send_message(**kwargs) await self.async_send_message(**kwargs)
def send_message(self, message: str) -> None: def send_message(self, message: str, title: str | None = None) -> None:
"""Send a message.""" """Send a message."""
raise NotImplementedError raise NotImplementedError
async def async_send_message(self, message: str) -> None: async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send a message.""" """Send a message."""
await self.hass.async_add_executor_job(partial(self.send_message, message)) kwargs: dict[str, Any] = {}
if (
title is not None
and self.supported_features
and self.supported_features & NotifyEntityFeature.TITLE
):
kwargs[ATTR_TITLE] = title
await self.hass.async_add_executor_job(
partial(self.send_message, message, **kwargs)
)

View File

@ -29,6 +29,13 @@ send_message:
required: true required: true
selector: selector:
text: text:
title:
required: false
selector:
text:
filter:
supported_features:
- notify.NotifyEntityFeature.TITLE
persistent_notification: persistent_notification:
fields: fields:

View File

@ -35,6 +35,10 @@
"message": { "message": {
"name": "Message", "name": "Message",
"description": "Your notification message." "description": "Your notification message."
},
"title": {
"name": "Title",
"description": "Title for your notification message."
} }
} }
}, },

View File

@ -98,6 +98,7 @@ def _entity_features() -> dict[str, type[IntFlag]]:
from homeassistant.components.light import LightEntityFeature from homeassistant.components.light import LightEntityFeature
from homeassistant.components.lock import LockEntityFeature from homeassistant.components.lock import LockEntityFeature
from homeassistant.components.media_player import MediaPlayerEntityFeature from homeassistant.components.media_player import MediaPlayerEntityFeature
from homeassistant.components.notify import NotifyEntityFeature
from homeassistant.components.remote import RemoteEntityFeature from homeassistant.components.remote import RemoteEntityFeature
from homeassistant.components.siren import SirenEntityFeature from homeassistant.components.siren import SirenEntityFeature
from homeassistant.components.todo import TodoListEntityFeature from homeassistant.components.todo import TodoListEntityFeature
@ -119,6 +120,7 @@ def _entity_features() -> dict[str, type[IntFlag]]:
"LightEntityFeature": LightEntityFeature, "LightEntityFeature": LightEntityFeature,
"LockEntityFeature": LockEntityFeature, "LockEntityFeature": LockEntityFeature,
"MediaPlayerEntityFeature": MediaPlayerEntityFeature, "MediaPlayerEntityFeature": MediaPlayerEntityFeature,
"NotifyEntityFeature": NotifyEntityFeature,
"RemoteEntityFeature": RemoteEntityFeature, "RemoteEntityFeature": RemoteEntityFeature,
"SirenEntityFeature": SirenEntityFeature, "SirenEntityFeature": SirenEntityFeature,
"TodoListEntityFeature": TodoListEntityFeature, "TodoListEntityFeature": TodoListEntityFeature,

View File

@ -69,7 +69,17 @@ async def test_sending_message(hass: HomeAssistant, events: list[Event]) -> None
await hass.services.async_call(notify.DOMAIN, notify.SERVICE_SEND_MESSAGE, data) await hass.services.async_call(notify.DOMAIN, notify.SERVICE_SEND_MESSAGE, data)
await hass.async_block_till_done() await hass.async_block_till_done()
last_event = events[-1] last_event = events[-1]
assert last_event.data[notify.ATTR_MESSAGE] == "Test message" assert last_event.data == {notify.ATTR_MESSAGE: "Test message"}
data[notify.ATTR_TITLE] = "My title"
# Test with Title
await hass.services.async_call(notify.DOMAIN, notify.SERVICE_SEND_MESSAGE, data)
await hass.async_block_till_done()
last_event = events[-1]
assert last_event.data == {
notify.ATTR_MESSAGE: "Test message",
notify.ATTR_TITLE: "My title",
}
async def test_calling_notify_from_script_loaded_from_yaml( async def test_calling_notify_from_script_loaded_from_yaml(

View File

@ -12,6 +12,7 @@ from homeassistant.components.notify import (
SERVICE_SEND_MESSAGE, SERVICE_SEND_MESSAGE,
NotifyEntity, NotifyEntity,
NotifyEntityDescription, NotifyEntityDescription,
NotifyEntityFeature,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN, Platform from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN, Platform
@ -27,7 +28,8 @@ from tests.common import (
setup_test_component_platform, setup_test_component_platform,
) )
TEST_KWARGS = {"message": "Test message"} TEST_KWARGS = {notify.ATTR_MESSAGE: "Test message"}
TEST_KWARGS_TITLE = {notify.ATTR_MESSAGE: "Test message", notify.ATTR_TITLE: "My title"}
class MockNotifyEntity(MockEntity, NotifyEntity): class MockNotifyEntity(MockEntity, NotifyEntity):
@ -35,9 +37,9 @@ class MockNotifyEntity(MockEntity, NotifyEntity):
send_message_mock_calls = MagicMock() send_message_mock_calls = MagicMock()
async def async_send_message(self, message: str) -> None: async def async_send_message(self, message: str, title: str | None = None) -> None:
"""Send a notification message.""" """Send a notification message."""
self.send_message_mock_calls(message=message) self.send_message_mock_calls(message, title=title)
class MockNotifyEntityNonAsync(MockEntity, NotifyEntity): class MockNotifyEntityNonAsync(MockEntity, NotifyEntity):
@ -45,9 +47,9 @@ class MockNotifyEntityNonAsync(MockEntity, NotifyEntity):
send_message_mock_calls = MagicMock() send_message_mock_calls = MagicMock()
def send_message(self, message: str) -> None: def send_message(self, message: str, title: str | None = None) -> None:
"""Send a notification message.""" """Send a notification message."""
self.send_message_mock_calls(message=message) self.send_message_mock_calls(message, title=title)
async def help_async_setup_entry_init( async def help_async_setup_entry_init(
@ -132,6 +134,58 @@ async def test_send_message_service(
assert await hass.config_entries.async_unload(config_entry.entry_id) assert await hass.config_entries.async_unload(config_entry.entry_id)
@pytest.mark.parametrize(
"entity",
[
MockNotifyEntityNonAsync(
name="test",
entity_id="notify.test",
supported_features=NotifyEntityFeature.TITLE,
),
MockNotifyEntity(
name="test",
entity_id="notify.test",
supported_features=NotifyEntityFeature.TITLE,
),
],
ids=["non_async", "async"],
)
async def test_send_message_service_with_title(
hass: HomeAssistant, config_flow_fixture: None, entity: NotifyEntity
) -> None:
"""Test send_message service."""
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
mock_integration(
hass,
MockModule(
"test",
async_setup_entry=help_async_setup_entry_init,
async_unload_entry=help_async_unload_entry,
),
)
setup_test_component_platform(hass, DOMAIN, [entity], from_config_entry=True)
assert await hass.config_entries.async_setup(config_entry.entry_id)
state = hass.states.get("notify.test")
assert state.state is STATE_UNKNOWN
await hass.services.async_call(
DOMAIN,
SERVICE_SEND_MESSAGE,
copy.deepcopy(TEST_KWARGS_TITLE) | {"entity_id": "notify.test"},
blocking=True,
)
await hass.async_block_till_done()
entity.send_message_mock_calls.assert_called_once_with(
TEST_KWARGS_TITLE[notify.ATTR_MESSAGE],
title=TEST_KWARGS_TITLE[notify.ATTR_TITLE],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("state", "init_state"), ("state", "init_state"),
[ [
@ -202,12 +256,12 @@ async def test_name(hass: HomeAssistant, config_flow_fixture: None) -> None:
state = hass.states.get(entity1.entity_id) state = hass.states.get(entity1.entity_id)
assert state assert state
assert state.attributes == {} assert state.attributes == {"supported_features": NotifyEntityFeature(0)}
state = hass.states.get(entity2.entity_id) state = hass.states.get(entity2.entity_id)
assert state assert state
assert state.attributes == {} assert state.attributes == {"supported_features": NotifyEntityFeature(0)}
state = hass.states.get(entity3.entity_id) state = hass.states.get(entity3.entity_id)
assert state assert state
assert state.attributes == {} assert state.attributes == {"supported_features": NotifyEntityFeature(0)}