From 4cb984842aba70520c4522be3d047286234b7f1b Mon Sep 17 00:00:00 2001 From: Raman Gupta <7243222+raman325@users.noreply.github.com> Date: Wed, 23 Oct 2019 01:26:29 -0400 Subject: [PATCH] Support custom source type for MQTT device tracker (#27838) * support custom source type for MQTT device tracker * fix typo * add abbreviation --- .../components/mqtt/abbreviations.py | 1 + .../components/mqtt/device_tracker.py | 13 +++++--- tests/components/mqtt/test_device_tracker.py | 31 ++++++++++++++++++- 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/mqtt/abbreviations.py b/homeassistant/components/mqtt/abbreviations.py index 5a5ed4555db..5e995494a64 100644 --- a/homeassistant/components/mqtt/abbreviations.py +++ b/homeassistant/components/mqtt/abbreviations.py @@ -130,6 +130,7 @@ ABBREVIATIONS = { "spd_stat_t": "speed_state_topic", "spd_val_tpl": "speed_value_template", "spds": "speeds", + "src_type": "source_type", "stat_clsd": "state_closed", "stat_off": "state_off", "stat_on": "state_on", diff --git a/homeassistant/components/mqtt/device_tracker.py b/homeassistant/components/mqtt/device_tracker.py index c9cce3ebeda..d25d7ce21d3 100644 --- a/homeassistant/components/mqtt/device_tracker.py +++ b/homeassistant/components/mqtt/device_tracker.py @@ -4,7 +4,7 @@ import logging import voluptuous as vol from homeassistant.components import mqtt -from homeassistant.components.device_tracker import PLATFORM_SCHEMA +from homeassistant.components.device_tracker import PLATFORM_SCHEMA, SOURCE_TYPES from homeassistant.core import callback import homeassistant.helpers.config_validation as cv from homeassistant.const import CONF_DEVICES, STATE_NOT_HOME, STATE_HOME @@ -15,12 +15,14 @@ _LOGGER = logging.getLogger(__name__) CONF_PAYLOAD_HOME = "payload_home" CONF_PAYLOAD_NOT_HOME = "payload_not_home" +CONF_SOURCE_TYPE = "source_type" PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(mqtt.SCHEMA_BASE).extend( { vol.Required(CONF_DEVICES): {cv.string: mqtt.valid_subscribe_topic}, vol.Optional(CONF_PAYLOAD_HOME, default=STATE_HOME): cv.string, vol.Optional(CONF_PAYLOAD_NOT_HOME, default=STATE_NOT_HOME): cv.string, + vol.Optional(CONF_SOURCE_TYPE): vol.In(SOURCE_TYPES), } ) @@ -31,6 +33,7 @@ async def async_setup_scanner(hass, config, async_see, discovery_info=None): qos = config[CONF_QOS] payload_home = config[CONF_PAYLOAD_HOME] payload_not_home = config[CONF_PAYLOAD_NOT_HOME] + source_type = config.get(CONF_SOURCE_TYPE) for dev_id, topic in devices.items(): @@ -44,9 +47,11 @@ async def async_setup_scanner(hass, config, async_see, discovery_info=None): else: location_name = msg.payload - hass.async_create_task( - async_see(dev_id=dev_id, location_name=location_name) - ) + see_args = {"dev_id": dev_id, "location_name": location_name} + if source_type: + see_args["source_type"] = source_type + + hass.async_create_task(async_see(**see_args)) await mqtt.async_subscribe(hass, topic, async_message_received, qos) diff --git a/tests/components/mqtt/test_device_tracker.py b/tests/components/mqtt/test_device_tracker.py index 14180d2dcf9..71348fcf5cb 100644 --- a/tests/components/mqtt/test_device_tracker.py +++ b/tests/components/mqtt/test_device_tracker.py @@ -3,7 +3,10 @@ from asynctest import patch import pytest from homeassistant.components import device_tracker -from homeassistant.components.device_tracker.const import ENTITY_ID_FORMAT +from homeassistant.components.device_tracker.const import ( + ENTITY_ID_FORMAT, + SOURCE_TYPE_BLUETOOTH, +) from homeassistant.const import CONF_PLATFORM, STATE_HOME, STATE_NOT_HOME from homeassistant.setup import async_setup_component @@ -218,3 +221,29 @@ async def test_not_matching_custom_payload_for_home_and_not_home( await hass.async_block_till_done() assert hass.states.get(entity_id).state != STATE_HOME assert hass.states.get(entity_id).state != STATE_NOT_HOME + + +async def test_matching_source_type(hass, mock_device_tracker_conf): + """Test setting source type.""" + dev_id = "paulus" + entity_id = ENTITY_ID_FORMAT.format(dev_id) + topic = "/location/paulus" + source_type = SOURCE_TYPE_BLUETOOTH + location = "work" + + hass.config.components = set(["mqtt", "zone"]) + assert await async_setup_component( + hass, + device_tracker.DOMAIN, + { + device_tracker.DOMAIN: { + CONF_PLATFORM: "mqtt", + "devices": {dev_id: topic}, + "source_type": source_type, + } + }, + ) + + async_fire_mqtt_message(hass, topic, location) + await hass.async_block_till_done() + assert hass.states.get(entity_id).attributes["source_type"] == SOURCE_TYPE_BLUETOOTH