diff --git a/homeassistant/components/influxdb.py b/homeassistant/components/influxdb.py index 430c5cbe4c6..58479b6c14e 100644 --- a/homeassistant/components/influxdb.py +++ b/homeassistant/components/influxdb.py @@ -10,8 +10,8 @@ import voluptuous as vol from homeassistant.const import ( EVENT_STATE_CHANGED, STATE_UNAVAILABLE, STATE_UNKNOWN, CONF_HOST, - CONF_PORT, CONF_SSL, CONF_VERIFY_SSL, CONF_USERNAME, CONF_BLACKLIST, - CONF_PASSWORD, CONF_WHITELIST) + CONF_PORT, CONF_SSL, CONF_VERIFY_SSL, CONF_USERNAME, CONF_PASSWORD, + CONF_EXCLUDE, CONF_INCLUDE, CONF_DOMAINS, CONF_ENTITIES) from homeassistant.helpers import state as state_helper import homeassistant.helpers.config_validation as cv @@ -23,6 +23,7 @@ CONF_DB_NAME = 'database' CONF_TAGS = 'tags' CONF_DEFAULT_MEASUREMENT = 'default_measurement' CONF_OVERRIDE_MEASUREMENT = 'override_measurement' +CONF_BLACKLIST_DOMAINS = "blacklist_domains" DEFAULT_DATABASE = 'home_assistant' DEFAULT_VERIFY_SSL = True @@ -34,8 +35,16 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_HOST): cv.string, vol.Inclusive(CONF_USERNAME, 'authentication'): cv.string, vol.Inclusive(CONF_PASSWORD, 'authentication'): cv.string, - vol.Optional(CONF_BLACKLIST, default=[]): - vol.All(cv.ensure_list, [cv.entity_id]), + vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({ + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): + vol.All(cv.ensure_list, [cv.string]) + }), + vol.Optional(CONF_INCLUDE, default={}): vol.Schema({ + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): + vol.All(cv.ensure_list, [cv.string]) + }), vol.Optional(CONF_DB_NAME, default=DEFAULT_DATABASE): cv.string, vol.Optional(CONF_PORT): cv.port, vol.Optional(CONF_SSL): cv.boolean, @@ -43,8 +52,6 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_OVERRIDE_MEASUREMENT): cv.string, vol.Optional(CONF_TAGS, default={}): vol.Schema({cv.string: cv.string}), - vol.Optional(CONF_WHITELIST, default=[]): - vol.All(cv.ensure_list, [cv.entity_id]), vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, }), }, extra=vol.ALLOW_EXTRA) @@ -77,8 +84,12 @@ def setup(hass, config): if CONF_SSL in conf: kwargs['ssl'] = conf[CONF_SSL] - blacklist = conf.get(CONF_BLACKLIST) - whitelist = conf.get(CONF_WHITELIST) + include = conf.get(CONF_INCLUDE, {}) + exclude = conf.get(CONF_EXCLUDE, {}) + whitelist_e = set(include.get(CONF_ENTITIES, [])) + whitelist_d = set(include.get(CONF_DOMAINS, [])) + blacklist_e = set(exclude.get(CONF_ENTITIES, [])) + blacklist_d = set(exclude.get(CONF_DOMAINS, [])) tags = conf.get(CONF_TAGS) default_measurement = conf.get(CONF_DEFAULT_MEASUREMENT) override_measurement = conf.get(CONF_OVERRIDE_MEASUREMENT) @@ -97,11 +108,13 @@ def setup(hass, config): state = event.data.get('new_state') if state is None or state.state in ( STATE_UNKNOWN, '', STATE_UNAVAILABLE) or \ - state.entity_id in blacklist: + state.entity_id in blacklist_e or \ + state.domain in blacklist_d: return try: - if whitelist and state.entity_id not in whitelist: + if (whitelist_e and state.entity_id not in whitelist_e) or \ + (whitelist_d and state.domain not in whitelist_d): return _state = float(state_helper.state_as_number(state)) diff --git a/tests/components/test_influxdb.py b/tests/components/test_influxdb.py index c1ad2672365..ab1f8916c37 100644 --- a/tests/components/test_influxdb.py +++ b/tests/components/test_influxdb.py @@ -96,7 +96,10 @@ class TestInfluxDB(unittest.TestCase): 'host': 'host', 'username': 'user', 'password': 'pass', - 'blacklist': ['fake.blacklisted'] + 'exclude': { + 'entities': ['fake.blacklisted'], + 'domains': ['another_fake'] + } } } assert setup_component(self.hass, influxdb.DOMAIN, config) @@ -273,6 +276,129 @@ class TestInfluxDB(unittest.TestCase): self.assertFalse(mock_client.return_value.write_points.called) mock_client.return_value.write_points.reset_mock() + def test_event_listener_blacklist_domain(self, mock_client): + """Test the event listener against a blacklist.""" + self._setup() + + for domain in ('ok', 'another_fake'): + state = mock.MagicMock( + state=1, domain=domain, + entity_id='{}.something'.format(domain), + object_id='something', attributes={}) + event = mock.MagicMock(data={'new_state': state}, time_fired=12345) + body = [{ + 'measurement': '{}.something'.format(domain), + 'tags': { + 'domain': domain, + 'entity_id': 'something', + }, + 'time': 12345, + 'fields': { + 'value': 1, + }, + }] + self.handler_method(event) + if domain == 'ok': + self.assertEqual( + mock_client.return_value.write_points.call_count, 1 + ) + self.assertEqual( + mock_client.return_value.write_points.call_args, + mock.call(body) + ) + else: + self.assertFalse(mock_client.return_value.write_points.called) + mock_client.return_value.write_points.reset_mock() + + def test_event_listener_whitelist(self, mock_client): + """Test the event listener against a whitelist.""" + config = { + 'influxdb': { + 'host': 'host', + 'username': 'user', + 'password': 'pass', + 'include': { + 'entities': ['fake.included'], + } + } + } + assert setup_component(self.hass, influxdb.DOMAIN, config) + self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] + + for entity_id in ('included', 'default'): + state = mock.MagicMock( + state=1, domain='fake', entity_id='fake.{}'.format(entity_id), + object_id=entity_id, attributes={}) + event = mock.MagicMock(data={'new_state': state}, time_fired=12345) + body = [{ + 'measurement': 'fake.{}'.format(entity_id), + 'tags': { + 'domain': 'fake', + 'entity_id': entity_id, + }, + 'time': 12345, + 'fields': { + 'value': 1, + }, + }] + self.handler_method(event) + if entity_id == 'included': + self.assertEqual( + mock_client.return_value.write_points.call_count, 1 + ) + self.assertEqual( + mock_client.return_value.write_points.call_args, + mock.call(body) + ) + else: + self.assertFalse(mock_client.return_value.write_points.called) + mock_client.return_value.write_points.reset_mock() + + def test_event_listener_whitelist_domain(self, mock_client): + """Test the event listener against a whitelist.""" + config = { + 'influxdb': { + 'host': 'host', + 'username': 'user', + 'password': 'pass', + 'include': { + 'domains': ['fake'], + } + } + } + assert setup_component(self.hass, influxdb.DOMAIN, config) + self.handler_method = self.hass.bus.listen.call_args_list[0][0][1] + + for domain in ('fake', 'another_fake'): + state = mock.MagicMock( + state=1, domain=domain, + entity_id='{}.something'.format(domain), + object_id='something', attributes={}) + event = mock.MagicMock(data={'new_state': state}, time_fired=12345) + body = [{ + 'measurement': '{}.something'.format(domain), + 'tags': { + 'domain': domain, + 'entity_id': 'something', + }, + 'time': 12345, + 'fields': { + 'value': 1, + }, + }] + self.handler_method(event) + if domain == 'fake': + self.assertEqual( + mock_client.return_value.write_points.call_count, 1 + ) + self.assertEqual( + mock_client.return_value.write_points.call_args, + mock.call(body) + ) + else: + self.assertFalse(mock_client.return_value.write_points.called) + mock_client.return_value.write_points.reset_mock() + def test_event_listener_invalid_type(self, mock_client): """Test the event listener when an attirbute has an invalid type.""" self._setup() @@ -343,7 +469,9 @@ class TestInfluxDB(unittest.TestCase): 'username': 'user', 'password': 'pass', 'default_measurement': 'state', - 'blacklist': ['fake.blacklisted'] + 'exclude': { + 'entities': ['fake.blacklisted'] + } } } assert setup_component(self.hass, influxdb.DOMAIN, config)