diff --git a/homeassistant/components/notify/html5.py b/homeassistant/components/notify/html5.py
index 2314722a2ab..a979ab5fb2f 100644
--- a/homeassistant/components/notify/html5.py
+++ b/homeassistant/components/notify/html5.py
@@ -169,15 +169,35 @@ class HTML5PushRegistrationView(HomeAssistantView):
return self.json_message(
humanize_error(data, ex), HTTP_BAD_REQUEST)
- name = ensure_unique_string('unnamed device', self.registrations)
+ name = self.find_registration_name(data)
+ previous_registration = self.registrations.get(name)
self.registrations[name] = data
- if not save_json(self.json_path, self.registrations):
+ try:
+ hass = request.app['hass']
+
+ yield from hass.async_add_job(save_json, self.json_path,
+ self.registrations)
+ return self.json_message(
+ 'Push notification subscriber registered.')
+ except HomeAssistantError:
+ if previous_registration is not None:
+ self.registrations[name] = previous_registration
+ else:
+ self.registrations.pop(name)
+
return self.json_message(
'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR)
- return self.json_message('Push notification subscriber registered.')
+ def find_registration_name(self, data):
+ """Find a registration name matching data or generate a unique one."""
+ endpoint = data.get(ATTR_SUBSCRIPTION).get(ATTR_ENDPOINT)
+ for key, registration in self.registrations.items():
+ subscription = registration.get(ATTR_SUBSCRIPTION)
+ if subscription.get(ATTR_ENDPOINT) == endpoint:
+ return key
+ return ensure_unique_string('unnamed device', self.registrations)
@asyncio.coroutine
def delete(self, request):
@@ -202,7 +222,12 @@ class HTML5PushRegistrationView(HomeAssistantView):
reg = self.registrations.pop(found)
- if not save_json(self.json_path, self.registrations):
+ try:
+ hass = request.app['hass']
+
+ yield from hass.async_add_job(save_json, self.json_path,
+ self.registrations)
+ except HomeAssistantError:
self.registrations[found] = reg
return self.json_message(
'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR)
diff --git a/tests/components/notify/test_html5.py b/tests/components/notify/test_html5.py
index c3998b6db64..6fb2e6454de 100644
--- a/tests/components/notify/test_html5.py
+++ b/tests/components/notify/test_html5.py
@@ -4,10 +4,14 @@ import json
from unittest.mock import patch, MagicMock, mock_open
from aiohttp.hdrs import AUTHORIZATION
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.util.json import save_json
from homeassistant.components.notify import html5
from tests.common import mock_http_component_app
+CONFIG_FILE = 'file.conf'
+
SUBSCRIPTION_1 = {
'browser': 'chrome',
'subscription': {
@@ -108,36 +112,30 @@ class TestHtml5Notify(object):
'unnamed device': SUBSCRIPTION_1,
}
- m = mock_open()
- with patch(
- 'homeassistant.util.json.open',
- m, create=True
- ):
- hass.config.path.return_value = 'file.conf'
- service = html5.get_service(hass, {})
+ hass.config.path.return_value = CONFIG_FILE
+ service = html5.get_service(hass, {})
- assert service is not None
+ assert service is not None
- # assert hass.called
- assert len(hass.mock_calls) == 3
+ assert len(hass.mock_calls) == 3
- view = hass.mock_calls[1][1][0]
- assert view.json_path == hass.config.path.return_value
- assert view.registrations == {}
+ view = hass.mock_calls[1][1][0]
+ assert view.json_path == hass.config.path.return_value
+ assert view.registrations == {}
- hass.loop = loop
- app = mock_http_component_app(hass)
- view.register(app.router)
- client = yield from test_client(app)
- hass.http.is_banned_ip.return_value = False
- resp = yield from client.post(REGISTER_URL,
- data=json.dumps(SUBSCRIPTION_1))
+ hass.loop = loop
+ app = mock_http_component_app(hass)
+ view.register(app.router)
+ client = yield from test_client(app)
+ hass.http.is_banned_ip.return_value = False
+ resp = yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_1))
- content = yield from resp.text()
- assert resp.status == 200, content
- assert view.registrations == expected
- handle = m()
- assert json.loads(handle.write.call_args[0][0]) == expected
+ content = yield from resp.text()
+ assert resp.status == 200, content
+ assert view.registrations == expected
+
+ hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
@asyncio.coroutine
def test_registering_new_device_expiration_view(self, loop, test_client):
@@ -147,36 +145,114 @@ class TestHtml5Notify(object):
'unnamed device': SUBSCRIPTION_4,
}
- m = mock_open()
- with patch(
- 'homeassistant.util.json.open',
- m, create=True
- ):
- hass.config.path.return_value = 'file.conf'
- service = html5.get_service(hass, {})
+ hass.config.path.return_value = CONFIG_FILE
+ service = html5.get_service(hass, {})
- assert service is not None
+ assert service is not None
- # assert hass.called
- assert len(hass.mock_calls) == 3
+ # assert hass.called
+ assert len(hass.mock_calls) == 3
- view = hass.mock_calls[1][1][0]
- assert view.json_path == hass.config.path.return_value
- assert view.registrations == {}
+ view = hass.mock_calls[1][1][0]
+ assert view.json_path == hass.config.path.return_value
+ assert view.registrations == {}
- hass.loop = loop
- app = mock_http_component_app(hass)
- view.register(app.router)
- client = yield from test_client(app)
- hass.http.is_banned_ip.return_value = False
- resp = yield from client.post(REGISTER_URL,
- data=json.dumps(SUBSCRIPTION_4))
+ hass.loop = loop
+ app = mock_http_component_app(hass)
+ view.register(app.router)
+ client = yield from test_client(app)
+ hass.http.is_banned_ip.return_value = False
+ resp = yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_4))
- content = yield from resp.text()
- assert resp.status == 200, content
- assert view.registrations == expected
- handle = m()
- assert json.loads(handle.write.call_args[0][0]) == expected
+ content = yield from resp.text()
+ assert resp.status == 200, content
+ assert view.registrations == expected
+
+ hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
+
+ @asyncio.coroutine
+ def test_registering_new_device_fails_view(self, loop, test_client):
+ """Test subs. are not altered when registering a new device fails."""
+ hass = MagicMock()
+ expected = {}
+
+ hass.config.path.return_value = CONFIG_FILE
+ html5.get_service(hass, {})
+ view = hass.mock_calls[1][1][0]
+
+ hass.loop = loop
+ app = mock_http_component_app(hass)
+ view.register(app.router)
+ client = yield from test_client(app)
+ hass.http.is_banned_ip.return_value = False
+
+ hass.async_add_job.side_effect = HomeAssistantError()
+
+ resp = yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_1))
+
+ content = yield from resp.text()
+ assert resp.status == 500, content
+ assert view.registrations == expected
+
+ @asyncio.coroutine
+ def test_registering_existing_device_view(self, loop, test_client):
+ """Test subscription is updated when registering existing device."""
+ hass = MagicMock()
+ expected = {
+ 'unnamed device': SUBSCRIPTION_4,
+ }
+
+ hass.config.path.return_value = CONFIG_FILE
+ html5.get_service(hass, {})
+ view = hass.mock_calls[1][1][0]
+
+ hass.loop = loop
+ app = mock_http_component_app(hass)
+ view.register(app.router)
+ client = yield from test_client(app)
+ hass.http.is_banned_ip.return_value = False
+
+ yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_1))
+ resp = yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_4))
+
+ content = yield from resp.text()
+ assert resp.status == 200, content
+ assert view.registrations == expected
+
+ hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
+
+ @asyncio.coroutine
+ def test_registering_existing_device_fails_view(self, loop, test_client):
+ """Test sub. is not updated when registering existing device fails."""
+ hass = MagicMock()
+ expected = {
+ 'unnamed device': SUBSCRIPTION_1,
+ }
+
+ hass.config.path.return_value = CONFIG_FILE
+ html5.get_service(hass, {})
+ view = hass.mock_calls[1][1][0]
+
+ hass.loop = loop
+ app = mock_http_component_app(hass)
+ view.register(app.router)
+ client = yield from test_client(app)
+ hass.http.is_banned_ip.return_value = False
+
+ yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_1))
+
+ hass.async_add_job.side_effect = HomeAssistantError()
+ resp = yield from client.post(REGISTER_URL,
+ data=json.dumps(SUBSCRIPTION_4))
+
+ content = yield from resp.text()
+ assert resp.status == 500, content
+ assert view.registrations == expected
@asyncio.coroutine
def test_registering_new_device_validation(self, loop, test_client):
@@ -188,7 +264,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
- hass.config.path.return_value = 'file.conf'
+ hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@@ -240,7 +316,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
- hass.config.path.return_value = 'file.conf'
+ hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@@ -266,8 +342,9 @@ class TestHtml5Notify(object):
assert resp.status == 200, resp.response
assert view.registrations == config
- handle = m()
- assert json.loads(handle.write.call_args[0][0]) == config
+
+ hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
+ config)
@asyncio.coroutine
def test_unregister_device_view_handle_unknown_subscription(
@@ -285,7 +362,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
- hass.config.path.return_value = 'file.conf'
+ hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@@ -309,13 +386,13 @@ class TestHtml5Notify(object):
assert resp.status == 200, resp.response
assert view.registrations == config
- handle = m()
- assert handle.write.call_count == 0
+
+ hass.async_add_job.assert_not_called()
@asyncio.coroutine
- def test_unregistering_device_view_handles_json_safe_error(
+ def test_unregistering_device_view_handles_save_error(
self, loop, test_client):
- """Test that the HTML unregister view handles JSON write errors."""
+ """Test that the HTML unregister view handles save errors."""
hass = MagicMock()
config = {
@@ -328,7 +405,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
- hass.config.path.return_value = 'file.conf'
+ hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@@ -346,16 +423,13 @@ class TestHtml5Notify(object):
client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
- with patch('homeassistant.components.notify.html5.save_json',
- return_value=False):
- resp = yield from client.delete(REGISTER_URL, data=json.dumps({
- 'subscription': SUBSCRIPTION_1['subscription'],
- }))
+ hass.async_add_job.side_effect = HomeAssistantError()
+ resp = yield from client.delete(REGISTER_URL, data=json.dumps({
+ 'subscription': SUBSCRIPTION_1['subscription'],
+ }))
assert resp.status == 500, resp.response
assert view.registrations == config
- handle = m()
- assert handle.write.call_count == 0
@asyncio.coroutine
def test_callback_view_no_jwt(self, loop, test_client):
@@ -367,7 +441,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
- hass.config.path.return_value = 'file.conf'
+ hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {})
assert service is not None
@@ -404,7 +478,7 @@ class TestHtml5Notify(object):
'homeassistant.util.json.open',
m, create=True
):
- hass.config.path.return_value = 'file.conf'
+ hass.config.path.return_value = CONFIG_FILE
service = html5.get_service(hass, {'gcm_sender_id': '100'})
assert service is not None