Store notifications in component. Add ws endpoint for fetching. (#16503)

* Store notifications in component. Add ws endpoint for fetching.

* Comments
This commit is contained in:
Jerad Meisner 2018-09-11 02:39:30 -07:00 committed by Paulus Schoutsen
parent 20f6cb7cc7
commit 50fb59477a
3 changed files with 203 additions and 11 deletions

View File

@ -10,7 +10,6 @@ from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file from homeassistant.config import load_yaml_config_file
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -92,9 +91,10 @@ async def process_wrong_login(request):
msg = ('Login attempt or request with invalid authentication ' msg = ('Login attempt or request with invalid authentication '
'from {}'.format(remote_addr)) 'from {}'.format(remote_addr))
_LOGGER.warning(msg) _LOGGER.warning(msg)
persistent_notification.async_create(
request.app['hass'], msg, 'Login attempt failed', hass = request.app['hass']
NOTIFICATION_ID_LOGIN) hass.components.persistent_notification.async_create(
msg, 'Login attempt failed', NOTIFICATION_ID_LOGIN)
# Check if ban middleware is loaded # Check if ban middleware is loaded
if (KEY_BANNED_IPS not in request.app or if (KEY_BANNED_IPS not in request.app or
@ -108,15 +108,13 @@ async def process_wrong_login(request):
new_ban = IpBan(remote_addr) new_ban = IpBan(remote_addr)
request.app[KEY_BANNED_IPS].append(new_ban) request.app[KEY_BANNED_IPS].append(new_ban)
hass = request.app['hass']
await hass.async_add_job( await hass.async_add_job(
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban) update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban)
_LOGGER.warning( _LOGGER.warning(
"Banned IP %s for too many login attempts", remote_addr) "Banned IP %s for too many login attempts", remote_addr)
persistent_notification.async_create( hass.components.persistent_notification.async_create(
hass,
'Too many login attempts from {}'.format(remote_addr), 'Too many login attempts from {}'.format(remote_addr),
'Banning IP address', NOTIFICATION_ID_BAN) 'Banning IP address', NOTIFICATION_ID_BAN)

View File

@ -6,10 +6,12 @@ https://home-assistant.io/components/persistent_notification/
""" """
import asyncio import asyncio
import logging import logging
from collections import OrderedDict
from typing import Awaitable from typing import Awaitable
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import callback, HomeAssistant from homeassistant.core import callback, HomeAssistant
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -20,13 +22,17 @@ from homeassistant.util import slugify
ATTR_MESSAGE = 'message' ATTR_MESSAGE = 'message'
ATTR_NOTIFICATION_ID = 'notification_id' ATTR_NOTIFICATION_ID = 'notification_id'
ATTR_TITLE = 'title' ATTR_TITLE = 'title'
ATTR_STATUS = 'status'
DOMAIN = 'persistent_notification' DOMAIN = 'persistent_notification'
ENTITY_ID_FORMAT = DOMAIN + '.{}' ENTITY_ID_FORMAT = DOMAIN + '.{}'
EVENT_PERSISTENT_NOTIFICATIONS_UPDATED = 'persistent_notifications_updated'
SERVICE_CREATE = 'create' SERVICE_CREATE = 'create'
SERVICE_DISMISS = 'dismiss' SERVICE_DISMISS = 'dismiss'
SERVICE_MARK_READ = 'mark_read'
SCHEMA_SERVICE_CREATE = vol.Schema({ SCHEMA_SERVICE_CREATE = vol.Schema({
vol.Required(ATTR_MESSAGE): cv.template, vol.Required(ATTR_MESSAGE): cv.template,
@ -38,11 +44,21 @@ SCHEMA_SERVICE_DISMISS = vol.Schema({
vol.Required(ATTR_NOTIFICATION_ID): cv.string, vol.Required(ATTR_NOTIFICATION_ID): cv.string,
}) })
SCHEMA_SERVICE_MARK_READ = vol.Schema({
vol.Required(ATTR_NOTIFICATION_ID): cv.string,
})
DEFAULT_OBJECT_ID = 'notification' DEFAULT_OBJECT_ID = 'notification'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STATE = 'notifying' STATE = 'notifying'
STATUS_UNREAD = 'unread'
STATUS_READ = 'read'
WS_TYPE_GET_NOTIFICATIONS = 'persistent_notification/get'
SCHEMA_WS_GET = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): WS_TYPE_GET_NOTIFICATIONS,
})
@bind_hass @bind_hass
@ -76,7 +92,7 @@ def async_create(hass: HomeAssistant, message: str, title: str = None,
@callback @callback
@bind_hass @bind_hass
def async_dismiss(hass, notification_id): def async_dismiss(hass: HomeAssistant, notification_id: str) -> None:
"""Remove a notification.""" """Remove a notification."""
data = {ATTR_NOTIFICATION_ID: notification_id} data = {ATTR_NOTIFICATION_ID: notification_id}
@ -86,6 +102,9 @@ def async_dismiss(hass, notification_id):
@asyncio.coroutine @asyncio.coroutine
def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]: def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
"""Set up the persistent notification component.""" """Set up the persistent notification component."""
persistent_notifications = OrderedDict()
hass.data[DOMAIN] = {'notifications': persistent_notifications}
@callback @callback
def create_service(call): def create_service(call):
"""Handle a create notification service call.""" """Handle a create notification service call."""
@ -98,6 +117,8 @@ def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
else: else:
entity_id = async_generate_entity_id( entity_id = async_generate_entity_id(
ENTITY_ID_FORMAT, DEFAULT_OBJECT_ID, hass=hass) ENTITY_ID_FORMAT, DEFAULT_OBJECT_ID, hass=hass)
notification_id = entity_id.split('.')[1]
attr = {} attr = {}
if title is not None: if title is not None:
try: try:
@ -120,18 +141,72 @@ def async_setup(hass: HomeAssistant, config: dict) -> Awaitable[bool]:
hass.states.async_set(entity_id, STATE, attr) hass.states.async_set(entity_id, STATE, attr)
# Store notification and fire event
# This will eventually replace state machine storage
persistent_notifications[entity_id] = {
ATTR_MESSAGE: message,
ATTR_NOTIFICATION_ID: notification_id,
ATTR_STATUS: STATUS_UNREAD,
ATTR_TITLE: title,
}
hass.bus.async_fire(EVENT_PERSISTENT_NOTIFICATIONS_UPDATED)
@callback @callback
def dismiss_service(call): def dismiss_service(call):
"""Handle the dismiss notification service call.""" """Handle the dismiss notification service call."""
notification_id = call.data.get(ATTR_NOTIFICATION_ID) notification_id = call.data.get(ATTR_NOTIFICATION_ID)
entity_id = ENTITY_ID_FORMAT.format(slugify(notification_id)) entity_id = ENTITY_ID_FORMAT.format(slugify(notification_id))
if entity_id not in persistent_notifications:
return
hass.states.async_remove(entity_id) hass.states.async_remove(entity_id)
del persistent_notifications[entity_id]
hass.bus.async_fire(EVENT_PERSISTENT_NOTIFICATIONS_UPDATED)
@callback
def mark_read_service(call):
"""Handle the mark_read notification service call."""
notification_id = call.data.get(ATTR_NOTIFICATION_ID)
entity_id = ENTITY_ID_FORMAT.format(slugify(notification_id))
if entity_id not in persistent_notifications:
_LOGGER.error('Marking persistent_notification read failed: '
'Notification ID %s not found.', notification_id)
return
persistent_notifications[entity_id][ATTR_STATUS] = STATUS_READ
hass.bus.async_fire(EVENT_PERSISTENT_NOTIFICATIONS_UPDATED)
hass.services.async_register(DOMAIN, SERVICE_CREATE, create_service, hass.services.async_register(DOMAIN, SERVICE_CREATE, create_service,
SCHEMA_SERVICE_CREATE) SCHEMA_SERVICE_CREATE)
hass.services.async_register(DOMAIN, SERVICE_DISMISS, dismiss_service, hass.services.async_register(DOMAIN, SERVICE_DISMISS, dismiss_service,
SCHEMA_SERVICE_DISMISS) SCHEMA_SERVICE_DISMISS)
hass.services.async_register(DOMAIN, SERVICE_MARK_READ, mark_read_service,
SCHEMA_SERVICE_MARK_READ)
hass.components.websocket_api.async_register_command(
WS_TYPE_GET_NOTIFICATIONS, websocket_get_notifications,
SCHEMA_WS_GET
)
return True return True
@callback
def websocket_get_notifications(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg):
"""Return a list of persistent_notifications."""
connection.to_write.put_nowait(
websocket_api.result_message(msg['id'], [
{
key: data[key] for key in (ATTR_NOTIFICATION_ID, ATTR_MESSAGE,
ATTR_STATUS, ATTR_TITLE)
}
for data in hass.data[DOMAIN]['notifications'].values()
])
)

View File

@ -1,5 +1,6 @@
"""The tests for the persistent notification component.""" """The tests for the persistent notification component."""
from homeassistant.setup import setup_component from homeassistant.components import websocket_api
from homeassistant.setup import setup_component, async_setup_component
import homeassistant.components.persistent_notification as pn import homeassistant.components.persistent_notification as pn
from tests.common import get_test_home_assistant from tests.common import get_test_home_assistant
@ -19,7 +20,9 @@ class TestPersistentNotification:
def test_create(self): def test_create(self):
"""Test creating notification without title or notification id.""" """Test creating notification without title or notification id."""
notifications = self.hass.data[pn.DOMAIN]['notifications']
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0 assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
assert len(notifications) == 0
pn.create(self.hass, 'Hello World {{ 1 + 1 }}', pn.create(self.hass, 'Hello World {{ 1 + 1 }}',
title='{{ 1 + 1 }} beers') title='{{ 1 + 1 }} beers')
@ -27,54 +30,170 @@ class TestPersistentNotification:
entity_ids = self.hass.states.entity_ids(pn.DOMAIN) entity_ids = self.hass.states.entity_ids(pn.DOMAIN)
assert len(entity_ids) == 1 assert len(entity_ids) == 1
assert len(notifications) == 1
state = self.hass.states.get(entity_ids[0]) state = self.hass.states.get(entity_ids[0])
assert state.state == pn.STATE assert state.state == pn.STATE
assert state.attributes.get('message') == 'Hello World 2' assert state.attributes.get('message') == 'Hello World 2'
assert state.attributes.get('title') == '2 beers' assert state.attributes.get('title') == '2 beers'
notification = notifications.get(entity_ids[0])
assert notification['status'] == pn.STATUS_UNREAD
assert notification['message'] == 'Hello World 2'
assert notification['title'] == '2 beers'
notifications.clear()
def test_create_notification_id(self): def test_create_notification_id(self):
"""Ensure overwrites existing notification with same id.""" """Ensure overwrites existing notification with same id."""
notifications = self.hass.data[pn.DOMAIN]['notifications']
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0 assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
assert len(notifications) == 0
pn.create(self.hass, 'test', notification_id='Beer 2') pn.create(self.hass, 'test', notification_id='Beer 2')
self.hass.block_till_done() self.hass.block_till_done()
assert len(self.hass.states.entity_ids()) == 1 assert len(self.hass.states.entity_ids()) == 1
state = self.hass.states.get('persistent_notification.beer_2') assert len(notifications) == 1
entity_id = 'persistent_notification.beer_2'
state = self.hass.states.get(entity_id)
assert state.attributes.get('message') == 'test' assert state.attributes.get('message') == 'test'
notification = notifications.get(entity_id)
assert notification['message'] == 'test'
assert notification['title'] is None
pn.create(self.hass, 'test 2', notification_id='Beer 2') pn.create(self.hass, 'test 2', notification_id='Beer 2')
self.hass.block_till_done() self.hass.block_till_done()
# We should have overwritten old one # We should have overwritten old one
assert len(self.hass.states.entity_ids()) == 1 assert len(self.hass.states.entity_ids()) == 1
state = self.hass.states.get('persistent_notification.beer_2') state = self.hass.states.get(entity_id)
assert state.attributes.get('message') == 'test 2' assert state.attributes.get('message') == 'test 2'
notification = notifications.get(entity_id)
assert notification['message'] == 'test 2'
notifications.clear()
def test_create_template_error(self): def test_create_template_error(self):
"""Ensure we output templates if contain error.""" """Ensure we output templates if contain error."""
notifications = self.hass.data[pn.DOMAIN]['notifications']
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0 assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
assert len(notifications) == 0
pn.create(self.hass, '{{ message + 1 }}', '{{ title + 1 }}') pn.create(self.hass, '{{ message + 1 }}', '{{ title + 1 }}')
self.hass.block_till_done() self.hass.block_till_done()
entity_ids = self.hass.states.entity_ids(pn.DOMAIN) entity_ids = self.hass.states.entity_ids(pn.DOMAIN)
assert len(entity_ids) == 1 assert len(entity_ids) == 1
assert len(notifications) == 1
state = self.hass.states.get(entity_ids[0]) state = self.hass.states.get(entity_ids[0])
assert state.attributes.get('message') == '{{ message + 1 }}' assert state.attributes.get('message') == '{{ message + 1 }}'
assert state.attributes.get('title') == '{{ title + 1 }}' assert state.attributes.get('title') == '{{ title + 1 }}'
notification = notifications.get(entity_ids[0])
assert notification['message'] == '{{ message + 1 }}'
assert notification['title'] == '{{ title + 1 }}'
notifications.clear()
def test_dismiss_notification(self): def test_dismiss_notification(self):
"""Ensure removal of specific notification.""" """Ensure removal of specific notification."""
notifications = self.hass.data[pn.DOMAIN]['notifications']
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0 assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
assert len(notifications) == 0
pn.create(self.hass, 'test', notification_id='Beer 2') pn.create(self.hass, 'test', notification_id='Beer 2')
self.hass.block_till_done() self.hass.block_till_done()
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 1 assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 1
assert len(notifications) == 1
pn.dismiss(self.hass, notification_id='Beer 2') pn.dismiss(self.hass, notification_id='Beer 2')
self.hass.block_till_done() self.hass.block_till_done()
assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0 assert len(self.hass.states.entity_ids(pn.DOMAIN)) == 0
assert len(notifications) == 0
notifications.clear()
def test_mark_read(self):
"""Ensure notification is marked as Read."""
notifications = self.hass.data[pn.DOMAIN]['notifications']
assert len(notifications) == 0
pn.create(self.hass, 'test', notification_id='Beer 2')
self.hass.block_till_done()
entity_id = 'persistent_notification.beer_2'
assert len(notifications) == 1
notification = notifications.get(entity_id)
assert notification['status'] == pn.STATUS_UNREAD
self.hass.services.call(pn.DOMAIN, pn.SERVICE_MARK_READ, {
'notification_id': 'Beer 2'
})
self.hass.block_till_done()
assert len(notifications) == 1
notification = notifications.get(entity_id)
assert notification['status'] == pn.STATUS_READ
notifications.clear()
async def test_ws_get_notifications(hass, hass_ws_client):
"""Test websocket endpoint for retrieving persistent notifications."""
await async_setup_component(hass, pn.DOMAIN, {})
client = await hass_ws_client(hass)
await client.send_json({
'id': 5,
'type': 'persistent_notification/get'
})
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT
assert msg['success']
notifications = msg['result']
assert len(notifications) == 0
# Create
hass.components.persistent_notification.async_create(
'test', notification_id='Beer 2')
await client.send_json({
'id': 6,
'type': 'persistent_notification/get'
})
msg = await client.receive_json()
assert msg['id'] == 6
assert msg['type'] == websocket_api.TYPE_RESULT
assert msg['success']
notifications = msg['result']
assert len(notifications) == 1
notification = notifications[0]
assert notification['notification_id'] == 'Beer 2'
assert notification['message'] == 'test'
assert notification['title'] is None
assert notification['status'] == pn.STATUS_UNREAD
# Mark Read
await hass.services.async_call(pn.DOMAIN, pn.SERVICE_MARK_READ, {
'notification_id': 'Beer 2'
})
await client.send_json({
'id': 7,
'type': 'persistent_notification/get'
})
msg = await client.receive_json()
notifications = msg['result']
assert len(notifications) == 1
assert notifications[0]['status'] == pn.STATUS_READ
# Dismiss
hass.components.persistent_notification.async_dismiss('Beer 2')
await client.send_json({
'id': 8,
'type': 'persistent_notification/get'
})
msg = await client.receive_json()
notifications = msg['result']
assert len(notifications) == 0