Allow usage of words domain, service, call_id in service data

This commit is contained in:
Paulus Schoutsen 2016-01-30 15:16:31 -08:00
parent fd6086a5d6
commit b7722ec452
4 changed files with 25 additions and 17 deletions

View File

@ -11,6 +11,7 @@ from homeassistant.core import EventOrigin, State
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
from homeassistant.components.mqtt import SERVICE_PUBLISH as MQTT_SVC_PUBLISH
from homeassistant.const import (
ATTR_SERVICE_DATA,
MATCH_ALL,
EVENT_TIME_CHANGED,
EVENT_CALL_SERVICE,
@ -46,7 +47,7 @@ def setup(hass, config):
if (
event.data.get('domain') == MQTT_DOMAIN and
event.data.get('service') == MQTT_SVC_PUBLISH and
event.data.get('topic') == pub_topic
event.data[ATTR_SERVICE_DATA].get('topic') == pub_topic
):
return

View File

@ -67,6 +67,7 @@ ATTR_NOW = "now"
# Contains domain, service for a SERVICE_CALL event
ATTR_DOMAIN = "domain"
ATTR_SERVICE = "service"
ATTR_SERVICE_DATA = "service_data"
# Data for a SERVICE_EXECUTED event
ATTR_SERVICE_CALL_ID = "service_call_id"

View File

@ -19,7 +19,7 @@ from homeassistant.const import (
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
EVENT_CALL_SERVICE, ATTR_NOW, ATTR_DOMAIN, ATTR_SERVICE, MATCH_ALL,
EVENT_SERVICE_EXECUTED, ATTR_SERVICE_CALL_ID, EVENT_SERVICE_REGISTERED,
TEMP_CELCIUS, TEMP_FAHRENHEIT, ATTR_FRIENDLY_NAME)
TEMP_CELCIUS, TEMP_FAHRENHEIT, ATTR_FRIENDLY_NAME, ATTR_SERVICE_DATA)
from homeassistant.exceptions import (
HomeAssistantError, InvalidEntityFormatError)
import homeassistant.util as util
@ -555,13 +555,14 @@ class Service(object):
class ServiceCall(object):
"""Represents a call to a service."""
__slots__ = ['domain', 'service', 'data']
__slots__ = ['domain', 'service', 'data', 'call_id']
def __init__(self, domain, service, data=None):
def __init__(self, domain, service, data=None, call_id=None):
"""Initialize a service call."""
self.domain = domain
self.service = service
self.data = data or {}
self.call_id = call_id
def __repr__(self):
if self.data:
@ -633,10 +634,13 @@ class ServiceRegistry(object):
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
"""
call_id = self._generate_unique_id()
event_data = service_data or {}
event_data[ATTR_DOMAIN] = domain
event_data[ATTR_SERVICE] = service
event_data[ATTR_SERVICE_CALL_ID] = call_id
event_data = {
ATTR_DOMAIN: domain,
ATTR_SERVICE: service,
ATTR_SERVICE_DATA: service_data,
ATTR_SERVICE_CALL_ID: call_id,
}
if blocking:
executed_event = threading.Event()
@ -658,15 +662,16 @@ class ServiceRegistry(object):
def _event_to_service_call(self, event):
"""Callback for SERVICE_CALLED events from the event bus."""
service_data = dict(event.data)
domain = service_data.pop(ATTR_DOMAIN, None)
service = service_data.pop(ATTR_SERVICE, None)
service_data = event.data.get(ATTR_SERVICE_DATA)
domain = event.data.get(ATTR_DOMAIN)
service = event.data.get(ATTR_SERVICE)
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service):
return
service_handler = self._services[domain][service]
service_call = ServiceCall(domain, service, service_data)
service_call = ServiceCall(domain, service, service_data, call_id)
# Add a job to the pool that calls _execute_service
self._pool.add_job(JobPriority.EVENT_SERVICE,
@ -678,10 +683,9 @@ class ServiceRegistry(object):
service, call = service_and_call
service(call)
if ATTR_SERVICE_CALL_ID in call.data:
if call.call_id is not None:
self._bus.fire(
EVENT_SERVICE_EXECUTED,
{ATTR_SERVICE_CALL_ID: call.data[ATTR_SERVICE_CALL_ID]})
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
def _generate_unique_id(self):
"""Generate a unique service call id."""

View File

@ -63,8 +63,10 @@ class TestMQTT(unittest.TestCase):
self.hass.pool.block_till_done()
self.assertEqual(1, len(self.calls))
self.assertEqual('test-topic', self.calls[0][0].data[mqtt.ATTR_TOPIC])
self.assertEqual('test-payload', self.calls[0][0].data[mqtt.ATTR_PAYLOAD])
self.assertEqual('test-topic',
self.calls[0][0].data['service_data'][mqtt.ATTR_TOPIC])
self.assertEqual('test-payload',
self.calls[0][0].data['service_data'][mqtt.ATTR_PAYLOAD])
def test_service_call_without_topic_does_not_publush(self):
self.hass.bus.fire(EVENT_CALL_SERVICE, {