mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Allow usage of words domain, service, call_id in service data
This commit is contained in:
parent
fd6086a5d6
commit
b7722ec452
@ -11,6 +11,7 @@ from homeassistant.core import EventOrigin, State
|
|||||||
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
from homeassistant.components.mqtt import DOMAIN as MQTT_DOMAIN
|
||||||
from homeassistant.components.mqtt import SERVICE_PUBLISH as MQTT_SVC_PUBLISH
|
from homeassistant.components.mqtt import SERVICE_PUBLISH as MQTT_SVC_PUBLISH
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
|
ATTR_SERVICE_DATA,
|
||||||
MATCH_ALL,
|
MATCH_ALL,
|
||||||
EVENT_TIME_CHANGED,
|
EVENT_TIME_CHANGED,
|
||||||
EVENT_CALL_SERVICE,
|
EVENT_CALL_SERVICE,
|
||||||
@ -46,7 +47,7 @@ def setup(hass, config):
|
|||||||
if (
|
if (
|
||||||
event.data.get('domain') == MQTT_DOMAIN and
|
event.data.get('domain') == MQTT_DOMAIN and
|
||||||
event.data.get('service') == MQTT_SVC_PUBLISH and
|
event.data.get('service') == MQTT_SVC_PUBLISH and
|
||||||
event.data.get('topic') == pub_topic
|
event.data[ATTR_SERVICE_DATA].get('topic') == pub_topic
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -67,6 +67,7 @@ ATTR_NOW = "now"
|
|||||||
# Contains domain, service for a SERVICE_CALL event
|
# Contains domain, service for a SERVICE_CALL event
|
||||||
ATTR_DOMAIN = "domain"
|
ATTR_DOMAIN = "domain"
|
||||||
ATTR_SERVICE = "service"
|
ATTR_SERVICE = "service"
|
||||||
|
ATTR_SERVICE_DATA = "service_data"
|
||||||
|
|
||||||
# Data for a SERVICE_EXECUTED event
|
# Data for a SERVICE_EXECUTED event
|
||||||
ATTR_SERVICE_CALL_ID = "service_call_id"
|
ATTR_SERVICE_CALL_ID = "service_call_id"
|
||||||
|
@ -19,7 +19,7 @@ from homeassistant.const import (
|
|||||||
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
|
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
|
||||||
EVENT_CALL_SERVICE, ATTR_NOW, ATTR_DOMAIN, ATTR_SERVICE, MATCH_ALL,
|
EVENT_CALL_SERVICE, ATTR_NOW, ATTR_DOMAIN, ATTR_SERVICE, MATCH_ALL,
|
||||||
EVENT_SERVICE_EXECUTED, ATTR_SERVICE_CALL_ID, EVENT_SERVICE_REGISTERED,
|
EVENT_SERVICE_EXECUTED, ATTR_SERVICE_CALL_ID, EVENT_SERVICE_REGISTERED,
|
||||||
TEMP_CELCIUS, TEMP_FAHRENHEIT, ATTR_FRIENDLY_NAME)
|
TEMP_CELCIUS, TEMP_FAHRENHEIT, ATTR_FRIENDLY_NAME, ATTR_SERVICE_DATA)
|
||||||
from homeassistant.exceptions import (
|
from homeassistant.exceptions import (
|
||||||
HomeAssistantError, InvalidEntityFormatError)
|
HomeAssistantError, InvalidEntityFormatError)
|
||||||
import homeassistant.util as util
|
import homeassistant.util as util
|
||||||
@ -555,13 +555,14 @@ class Service(object):
|
|||||||
class ServiceCall(object):
|
class ServiceCall(object):
|
||||||
"""Represents a call to a service."""
|
"""Represents a call to a service."""
|
||||||
|
|
||||||
__slots__ = ['domain', 'service', 'data']
|
__slots__ = ['domain', 'service', 'data', 'call_id']
|
||||||
|
|
||||||
def __init__(self, domain, service, data=None):
|
def __init__(self, domain, service, data=None, call_id=None):
|
||||||
"""Initialize a service call."""
|
"""Initialize a service call."""
|
||||||
self.domain = domain
|
self.domain = domain
|
||||||
self.service = service
|
self.service = service
|
||||||
self.data = data or {}
|
self.data = data or {}
|
||||||
|
self.call_id = call_id
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if self.data:
|
if self.data:
|
||||||
@ -633,10 +634,13 @@ class ServiceRegistry(object):
|
|||||||
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
||||||
"""
|
"""
|
||||||
call_id = self._generate_unique_id()
|
call_id = self._generate_unique_id()
|
||||||
event_data = service_data or {}
|
|
||||||
event_data[ATTR_DOMAIN] = domain
|
event_data = {
|
||||||
event_data[ATTR_SERVICE] = service
|
ATTR_DOMAIN: domain,
|
||||||
event_data[ATTR_SERVICE_CALL_ID] = call_id
|
ATTR_SERVICE: service,
|
||||||
|
ATTR_SERVICE_DATA: service_data,
|
||||||
|
ATTR_SERVICE_CALL_ID: call_id,
|
||||||
|
}
|
||||||
|
|
||||||
if blocking:
|
if blocking:
|
||||||
executed_event = threading.Event()
|
executed_event = threading.Event()
|
||||||
@ -658,15 +662,16 @@ class ServiceRegistry(object):
|
|||||||
|
|
||||||
def _event_to_service_call(self, event):
|
def _event_to_service_call(self, event):
|
||||||
"""Callback for SERVICE_CALLED events from the event bus."""
|
"""Callback for SERVICE_CALLED events from the event bus."""
|
||||||
service_data = dict(event.data)
|
service_data = event.data.get(ATTR_SERVICE_DATA)
|
||||||
domain = service_data.pop(ATTR_DOMAIN, None)
|
domain = event.data.get(ATTR_DOMAIN)
|
||||||
service = service_data.pop(ATTR_SERVICE, None)
|
service = event.data.get(ATTR_SERVICE)
|
||||||
|
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||||
|
|
||||||
if not self.has_service(domain, service):
|
if not self.has_service(domain, service):
|
||||||
return
|
return
|
||||||
|
|
||||||
service_handler = self._services[domain][service]
|
service_handler = self._services[domain][service]
|
||||||
service_call = ServiceCall(domain, service, service_data)
|
service_call = ServiceCall(domain, service, service_data, call_id)
|
||||||
|
|
||||||
# Add a job to the pool that calls _execute_service
|
# Add a job to the pool that calls _execute_service
|
||||||
self._pool.add_job(JobPriority.EVENT_SERVICE,
|
self._pool.add_job(JobPriority.EVENT_SERVICE,
|
||||||
@ -678,10 +683,9 @@ class ServiceRegistry(object):
|
|||||||
service, call = service_and_call
|
service, call = service_and_call
|
||||||
service(call)
|
service(call)
|
||||||
|
|
||||||
if ATTR_SERVICE_CALL_ID in call.data:
|
if call.call_id is not None:
|
||||||
self._bus.fire(
|
self._bus.fire(
|
||||||
EVENT_SERVICE_EXECUTED,
|
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
|
||||||
{ATTR_SERVICE_CALL_ID: call.data[ATTR_SERVICE_CALL_ID]})
|
|
||||||
|
|
||||||
def _generate_unique_id(self):
|
def _generate_unique_id(self):
|
||||||
"""Generate a unique service call id."""
|
"""Generate a unique service call id."""
|
||||||
|
@ -63,8 +63,10 @@ class TestMQTT(unittest.TestCase):
|
|||||||
self.hass.pool.block_till_done()
|
self.hass.pool.block_till_done()
|
||||||
|
|
||||||
self.assertEqual(1, len(self.calls))
|
self.assertEqual(1, len(self.calls))
|
||||||
self.assertEqual('test-topic', self.calls[0][0].data[mqtt.ATTR_TOPIC])
|
self.assertEqual('test-topic',
|
||||||
self.assertEqual('test-payload', self.calls[0][0].data[mqtt.ATTR_PAYLOAD])
|
self.calls[0][0].data['service_data'][mqtt.ATTR_TOPIC])
|
||||||
|
self.assertEqual('test-payload',
|
||||||
|
self.calls[0][0].data['service_data'][mqtt.ATTR_PAYLOAD])
|
||||||
|
|
||||||
def test_service_call_without_topic_does_not_publush(self):
|
def test_service_call_without_topic_does_not_publush(self):
|
||||||
self.hass.bus.fire(EVENT_CALL_SERVICE, {
|
self.hass.bus.fire(EVENT_CALL_SERVICE, {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user