From 434d2afbfc9315d3d269f0fa0d1c1412c5e9f4e5 Mon Sep 17 00:00:00 2001 From: Taylor Peet Date: Wed, 31 Jan 2018 05:39:15 -0500 Subject: [PATCH] Influx import improvements (#11988) * Influx import improvements * fix line length issues * fixing pylint spaces * Added refined except clause * Fix progress bar and exclude issues * fix travis lint too many blank lines * Minor changes --- homeassistant/scripts/influxdb_import.py | 202 +++++++++++++++-------- 1 file changed, 131 insertions(+), 71 deletions(-) diff --git a/homeassistant/scripts/influxdb_import.py b/homeassistant/scripts/influxdb_import.py index c21ac4adad9..e91aeb8a0d7 100644 --- a/homeassistant/scripts/influxdb_import.py +++ b/homeassistant/scripts/influxdb_import.py @@ -1,7 +1,8 @@ -"""Script to import recorded data into influxdb.""" +"""Script to import recorded data into an Influx database.""" import argparse import json import os +import sys from typing import List @@ -11,11 +12,13 @@ import homeassistant.config as config_util def run(script_args: List) -> int: """Run the actual script.""" from sqlalchemy import create_engine + from sqlalchemy import func from sqlalchemy.orm import sessionmaker from influxdb import InfluxDBClient from homeassistant.components.recorder import models from homeassistant.helpers import state as state_helper from homeassistant.core import State + from homeassistant.core import HomeAssistantError parser = argparse.ArgumentParser( description="import data to influxDB.") @@ -99,8 +102,8 @@ def run(script_args: List) -> int: client = None if not simulate: - client = InfluxDBClient(args.host, args.port, - args.username, args.password) + client = InfluxDBClient( + args.host, args.port, args.username, args.password) client.switch_database(args.dbname) config_dir = os.path.join(os.getcwd(), args.config) # type: str @@ -116,105 +119,162 @@ def run(script_args: List) -> int: if not os.path.exists(src_db) and not args.uri: print("Fatal Error: Database '{}' does not exist " - "and no uri given".format(src_db)) + "and no URI given".format(src_db)) return 1 - uri = args.uri or "sqlite:///{}".format(src_db) + uri = args.uri or 'sqlite:///{}'.format(src_db) engine = create_engine(uri, echo=False) session_factory = sessionmaker(bind=engine) session = session_factory() step = int(args.step) + step_start = 0 tags = {} if args.tags: - tags.update(dict(elem.split(":") for elem in args.tags.split(","))) - excl_entities = args.exclude_entities.split(",") - excl_domains = args.exclude_domains.split(",") + tags.update(dict(elem.split(':') for elem in args.tags.split(','))) + excl_entities = args.exclude_entities.split(',') + excl_domains = args.exclude_domains.split(',') override_measurement = args.override_measurement default_measurement = args.default_measurement - query = session.query(models.Events).filter( - models.Events.event_type == "state_changed").order_by( - models.Events.time_fired) + query = session.query(func.count(models.Events.event_type)).filter( + models.Events.event_type == 'state_changed') + + total_events = query.scalar() + prefix_format = '{} of {}' points = [] + invalid_points = [] count = 0 from collections import defaultdict entities = defaultdict(int) + print_progress(0, total_events, prefix_format.format(0, total_events)) - for event in query: - event_data = json.loads(event.event_data) - state = State.from_dict(event_data.get("new_state")) + while True: - if not state or ( - excl_entities and state.entity_id in excl_entities) or ( - excl_domains and state.domain in excl_domains): - session.expunge(event) - continue + step_stop = step_start + step + if step_start > total_events: + print_progress(total_events, total_events, prefix_format.format( + total_events, total_events)) + break + query = session.query(models.Events).filter( + models.Events.event_type == 'state_changed').order_by( + models.Events.time_fired).slice(step_start, step_stop) - try: - _state = float(state_helper.state_as_number(state)) - _state_key = "value" - except ValueError: - _state = state.state - _state_key = "state" + for event in query: + event_data = json.loads(event.event_data) - if override_measurement: - measurement = override_measurement - else: - measurement = state.attributes.get('unit_of_measurement') - if measurement in (None, ''): - if default_measurement: - measurement = default_measurement - else: - measurement = state.entity_id + if not ('entity_id' in event_data) or ( + excl_entities and event_data[ + 'entity_id'] in excl_entities) or ( + excl_domains and event_data[ + 'entity_id'].split('.')[0] in excl_domains): + session.expunge(event) + continue - point = { - 'measurement': measurement, - 'tags': { - 'domain': state.domain, - 'entity_id': state.object_id, - }, - 'time': event.time_fired, - 'fields': { - _state_key: _state, + try: + state = State.from_dict(event_data.get('new_state')) + except HomeAssistantError: + invalid_points.append(event_data) + + if not state: + invalid_points.append(event_data) + continue + + try: + _state = float(state_helper.state_as_number(state)) + _state_key = 'value' + except ValueError: + _state = state.state + _state_key = 'state' + + if override_measurement: + measurement = override_measurement + else: + measurement = state.attributes.get('unit_of_measurement') + if measurement in (None, ''): + if default_measurement: + measurement = default_measurement + else: + measurement = state.entity_id + + point = { + 'measurement': measurement, + 'tags': { + 'domain': state.domain, + 'entity_id': state.object_id, + }, + 'time': event.time_fired, + 'fields': { + _state_key: _state, + } } - } - for key, value in state.attributes.items(): - if key != 'unit_of_measurement': - # If the key is already in fields - if key in point['fields']: - key = key + "_" - # Prevent column data errors in influxDB. - # For each value we try to cast it as float - # But if we can not do it we store the value - # as string add "_str" postfix to the field key - try: - point['fields'][key] = float(value) - except (ValueError, TypeError): - new_key = "{}_str".format(key) - point['fields'][new_key] = str(value) + for key, value in state.attributes.items(): + if key != 'unit_of_measurement': + # If the key is already in fields + if key in point['fields']: + key = key + '_' + # Prevent column data errors in influxDB. + # For each value we try to cast it as float + # But if we can not do it we store the value + # as string add "_str" postfix to the field key + try: + point['fields'][key] = float(value) + except (ValueError, TypeError): + new_key = '{}_str'.format(key) + point['fields'][new_key] = str(value) - entities[state.entity_id] += 1 - point['tags'].update(tags) - points.append(point) - session.expunge(event) - if len(points) >= step: + entities[state.entity_id] += 1 + point['tags'].update(tags) + points.append(point) + session.expunge(event) + + if points: if not simulate: - print("Write {} points to the database".format(len(points))) client.write_points(points) count += len(points) - points = [] + # This prevents the progress bar from going over 100% when + # the last step happens + print_progress((step_start + len( + points)), total_events, prefix_format.format( + step_start, total_events)) + else: + print_progress( + (step_start + step), total_events, prefix_format.format( + step_start, total_events)) - if points: - if not simulate: - print("Write {} points to the database".format(len(points))) - client.write_points(points) - count += len(points) + points = [] + step_start += step print("\nStatistics:") print("\n".join(["{:6}: {}".format(v, k) for k, v in sorted(entities.items(), key=lambda x: x[1])])) - print("\nImport finished {} points written".format(count)) + print("\nInvalid Points: {}".format(len(invalid_points))) + print("\nImport finished: {} points written".format(count)) return 0 + + +# Based on code at +# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console +def print_progress(iteration: int, total: int, prefix: str='', suffix: str='', + decimals: int=2, bar_length: int=68) -> None: + """Print progress bar. + + Call in a loop to create terminal progress bar + @params: + iteration - Required : current iteration (Int) + total - Required : total iterations (Int) + prefix - Optional : prefix string (Str) + suffix - Optional : suffix string (Str) + decimals - Optional : number of decimals in percent complete (Int) + barLength - Optional : character length of bar (Int) + """ + filled_length = int(round(bar_length * iteration / float(total))) + percents = round(100.00 * (iteration / float(total)), decimals) + line = '#' * filled_length + '-' * (bar_length - filled_length) + sys.stdout.write('%s [%s] %s%s %s\r' % (prefix, line, + percents, '%', suffix)) + sys.stdout.flush() + if iteration == total: + print('\n')