Influx import improvements (#11988)

* Influx import improvements

* fix line length issues

* fixing pylint spaces

* Added refined except clause

* Fix progress bar and exclude issues

* fix travis lint too many blank lines

* Minor changes
This commit is contained in:
Taylor Peet 2018-01-31 05:39:15 -05:00 committed by Fabian Affolter
parent 0376cc0917
commit 434d2afbfc

View File

@ -1,7 +1,8 @@
"""Script to import recorded data into influxdb.""" """Script to import recorded data into an Influx database."""
import argparse import argparse
import json import json
import os import os
import sys
from typing import List from typing import List
@ -11,11 +12,13 @@ import homeassistant.config as config_util
def run(script_args: List) -> int: def run(script_args: List) -> int:
"""Run the actual script.""" """Run the actual script."""
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from influxdb import InfluxDBClient from influxdb import InfluxDBClient
from homeassistant.components.recorder import models from homeassistant.components.recorder import models
from homeassistant.helpers import state as state_helper from homeassistant.helpers import state as state_helper
from homeassistant.core import State from homeassistant.core import State
from homeassistant.core import HomeAssistantError
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="import data to influxDB.") description="import data to influxDB.")
@ -99,8 +102,8 @@ def run(script_args: List) -> int:
client = None client = None
if not simulate: if not simulate:
client = InfluxDBClient(args.host, args.port, client = InfluxDBClient(
args.username, args.password) args.host, args.port, args.username, args.password)
client.switch_database(args.dbname) client.switch_database(args.dbname)
config_dir = os.path.join(os.getcwd(), args.config) # type: str config_dir = os.path.join(os.getcwd(), args.config) # type: str
@ -116,48 +119,74 @@ def run(script_args: List) -> int:
if not os.path.exists(src_db) and not args.uri: if not os.path.exists(src_db) and not args.uri:
print("Fatal Error: Database '{}' does not exist " print("Fatal Error: Database '{}' does not exist "
"and no uri given".format(src_db)) "and no URI given".format(src_db))
return 1 return 1
uri = args.uri or "sqlite:///{}".format(src_db) uri = args.uri or 'sqlite:///{}'.format(src_db)
engine = create_engine(uri, echo=False) engine = create_engine(uri, echo=False)
session_factory = sessionmaker(bind=engine) session_factory = sessionmaker(bind=engine)
session = session_factory() session = session_factory()
step = int(args.step) step = int(args.step)
step_start = 0
tags = {} tags = {}
if args.tags: if args.tags:
tags.update(dict(elem.split(":") for elem in args.tags.split(","))) tags.update(dict(elem.split(':') for elem in args.tags.split(',')))
excl_entities = args.exclude_entities.split(",") excl_entities = args.exclude_entities.split(',')
excl_domains = args.exclude_domains.split(",") excl_domains = args.exclude_domains.split(',')
override_measurement = args.override_measurement override_measurement = args.override_measurement
default_measurement = args.default_measurement default_measurement = args.default_measurement
query = session.query(models.Events).filter( query = session.query(func.count(models.Events.event_type)).filter(
models.Events.event_type == "state_changed").order_by( models.Events.event_type == 'state_changed')
models.Events.time_fired)
total_events = query.scalar()
prefix_format = '{} of {}'
points = [] points = []
invalid_points = []
count = 0 count = 0
from collections import defaultdict from collections import defaultdict
entities = defaultdict(int) entities = defaultdict(int)
print_progress(0, total_events, prefix_format.format(0, total_events))
while True:
step_stop = step_start + step
if step_start > total_events:
print_progress(total_events, total_events, prefix_format.format(
total_events, total_events))
break
query = session.query(models.Events).filter(
models.Events.event_type == 'state_changed').order_by(
models.Events.time_fired).slice(step_start, step_stop)
for event in query: for event in query:
event_data = json.loads(event.event_data) event_data = json.loads(event.event_data)
state = State.from_dict(event_data.get("new_state"))
if not state or ( if not ('entity_id' in event_data) or (
excl_entities and state.entity_id in excl_entities) or ( excl_entities and event_data[
excl_domains and state.domain in excl_domains): 'entity_id'] in excl_entities) or (
excl_domains and event_data[
'entity_id'].split('.')[0] in excl_domains):
session.expunge(event) session.expunge(event)
continue continue
try:
state = State.from_dict(event_data.get('new_state'))
except HomeAssistantError:
invalid_points.append(event_data)
if not state:
invalid_points.append(event_data)
continue
try: try:
_state = float(state_helper.state_as_number(state)) _state = float(state_helper.state_as_number(state))
_state_key = "value" _state_key = 'value'
except ValueError: except ValueError:
_state = state.state _state = state.state
_state_key = "state" _state_key = 'state'
if override_measurement: if override_measurement:
measurement = override_measurement measurement = override_measurement
@ -185,7 +214,7 @@ def run(script_args: List) -> int:
if key != 'unit_of_measurement': if key != 'unit_of_measurement':
# If the key is already in fields # If the key is already in fields
if key in point['fields']: if key in point['fields']:
key = key + "_" key = key + '_'
# Prevent column data errors in influxDB. # Prevent column data errors in influxDB.
# For each value we try to cast it as float # For each value we try to cast it as float
# But if we can not do it we store the value # But if we can not do it we store the value
@ -193,28 +222,59 @@ def run(script_args: List) -> int:
try: try:
point['fields'][key] = float(value) point['fields'][key] = float(value)
except (ValueError, TypeError): except (ValueError, TypeError):
new_key = "{}_str".format(key) new_key = '{}_str'.format(key)
point['fields'][new_key] = str(value) point['fields'][new_key] = str(value)
entities[state.entity_id] += 1 entities[state.entity_id] += 1
point['tags'].update(tags) point['tags'].update(tags)
points.append(point) points.append(point)
session.expunge(event) session.expunge(event)
if len(points) >= step:
if not simulate:
print("Write {} points to the database".format(len(points)))
client.write_points(points)
count += len(points)
points = []
if points: if points:
if not simulate: if not simulate:
print("Write {} points to the database".format(len(points)))
client.write_points(points) client.write_points(points)
count += len(points) count += len(points)
# This prevents the progress bar from going over 100% when
# the last step happens
print_progress((step_start + len(
points)), total_events, prefix_format.format(
step_start, total_events))
else:
print_progress(
(step_start + step), total_events, prefix_format.format(
step_start, total_events))
points = []
step_start += step
print("\nStatistics:") print("\nStatistics:")
print("\n".join(["{:6}: {}".format(v, k) for k, v print("\n".join(["{:6}: {}".format(v, k) for k, v
in sorted(entities.items(), key=lambda x: x[1])])) in sorted(entities.items(), key=lambda x: x[1])]))
print("\nImport finished {} points written".format(count)) print("\nInvalid Points: {}".format(len(invalid_points)))
print("\nImport finished: {} points written".format(count))
return 0 return 0
# Based on code at
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
def print_progress(iteration: int, total: int, prefix: str='', suffix: str='',
decimals: int=2, bar_length: int=68) -> None:
"""Print progress bar.
Call in a loop to create terminal progress bar
@params:
iteration - Required : current iteration (Int)
total - Required : total iterations (Int)
prefix - Optional : prefix string (Str)
suffix - Optional : suffix string (Str)
decimals - Optional : number of decimals in percent complete (Int)
barLength - Optional : character length of bar (Int)
"""
filled_length = int(round(bar_length * iteration / float(total)))
percents = round(100.00 * (iteration / float(total)), decimals)
line = '#' * filled_length + '-' * (bar_length - filled_length)
sys.stdout.write('%s [%s] %s%s %s\r' % (prefix, line,
percents, '%', suffix))
sys.stdout.flush()
if iteration == total:
print('\n')