Influx import improvements (#11988)

* Influx import improvements

* fix line length issues

* fixing pylint spaces

* Added refined except clause

* Fix progress bar and exclude issues

* fix travis lint too many blank lines

* Minor changes
This commit is contained in:
Taylor Peet 2018-01-31 05:39:15 -05:00 committed by Fabian Affolter
parent 0376cc0917
commit 434d2afbfc

View File

@ -1,7 +1,8 @@
"""Script to import recorded data into influxdb.""" """Script to import recorded data into an Influx database."""
import argparse import argparse
import json import json
import os import os
import sys
from typing import List from typing import List
@ -11,11 +12,13 @@ import homeassistant.config as config_util
def run(script_args: List) -> int: def run(script_args: List) -> int:
"""Run the actual script.""" """Run the actual script."""
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from influxdb import InfluxDBClient from influxdb import InfluxDBClient
from homeassistant.components.recorder import models from homeassistant.components.recorder import models
from homeassistant.helpers import state as state_helper from homeassistant.helpers import state as state_helper
from homeassistant.core import State from homeassistant.core import State
from homeassistant.core import HomeAssistantError
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="import data to influxDB.") description="import data to influxDB.")
@ -99,8 +102,8 @@ def run(script_args: List) -> int:
client = None client = None
if not simulate: if not simulate:
client = InfluxDBClient(args.host, args.port, client = InfluxDBClient(
args.username, args.password) args.host, args.port, args.username, args.password)
client.switch_database(args.dbname) client.switch_database(args.dbname)
config_dir = os.path.join(os.getcwd(), args.config) # type: str config_dir = os.path.join(os.getcwd(), args.config) # type: str
@ -116,105 +119,162 @@ def run(script_args: List) -> int:
if not os.path.exists(src_db) and not args.uri: if not os.path.exists(src_db) and not args.uri:
print("Fatal Error: Database '{}' does not exist " print("Fatal Error: Database '{}' does not exist "
"and no uri given".format(src_db)) "and no URI given".format(src_db))
return 1 return 1
uri = args.uri or "sqlite:///{}".format(src_db) uri = args.uri or 'sqlite:///{}'.format(src_db)
engine = create_engine(uri, echo=False) engine = create_engine(uri, echo=False)
session_factory = sessionmaker(bind=engine) session_factory = sessionmaker(bind=engine)
session = session_factory() session = session_factory()
step = int(args.step) step = int(args.step)
step_start = 0
tags = {} tags = {}
if args.tags: if args.tags:
tags.update(dict(elem.split(":") for elem in args.tags.split(","))) tags.update(dict(elem.split(':') for elem in args.tags.split(',')))
excl_entities = args.exclude_entities.split(",") excl_entities = args.exclude_entities.split(',')
excl_domains = args.exclude_domains.split(",") excl_domains = args.exclude_domains.split(',')
override_measurement = args.override_measurement override_measurement = args.override_measurement
default_measurement = args.default_measurement default_measurement = args.default_measurement
query = session.query(models.Events).filter( query = session.query(func.count(models.Events.event_type)).filter(
models.Events.event_type == "state_changed").order_by( models.Events.event_type == 'state_changed')
models.Events.time_fired)
total_events = query.scalar()
prefix_format = '{} of {}'
points = [] points = []
invalid_points = []
count = 0 count = 0
from collections import defaultdict from collections import defaultdict
entities = defaultdict(int) entities = defaultdict(int)
print_progress(0, total_events, prefix_format.format(0, total_events))
for event in query: while True:
event_data = json.loads(event.event_data)
state = State.from_dict(event_data.get("new_state"))
if not state or ( step_stop = step_start + step
excl_entities and state.entity_id in excl_entities) or ( if step_start > total_events:
excl_domains and state.domain in excl_domains): print_progress(total_events, total_events, prefix_format.format(
session.expunge(event) total_events, total_events))
continue break
query = session.query(models.Events).filter(
models.Events.event_type == 'state_changed').order_by(
models.Events.time_fired).slice(step_start, step_stop)
try: for event in query:
_state = float(state_helper.state_as_number(state)) event_data = json.loads(event.event_data)
_state_key = "value"
except ValueError:
_state = state.state
_state_key = "state"
if override_measurement: if not ('entity_id' in event_data) or (
measurement = override_measurement excl_entities and event_data[
else: 'entity_id'] in excl_entities) or (
measurement = state.attributes.get('unit_of_measurement') excl_domains and event_data[
if measurement in (None, ''): 'entity_id'].split('.')[0] in excl_domains):
if default_measurement: session.expunge(event)
measurement = default_measurement continue
else:
measurement = state.entity_id
point = { try:
'measurement': measurement, state = State.from_dict(event_data.get('new_state'))
'tags': { except HomeAssistantError:
'domain': state.domain, invalid_points.append(event_data)
'entity_id': state.object_id,
}, if not state:
'time': event.time_fired, invalid_points.append(event_data)
'fields': { continue
_state_key: _state,
try:
_state = float(state_helper.state_as_number(state))
_state_key = 'value'
except ValueError:
_state = state.state
_state_key = 'state'
if override_measurement:
measurement = override_measurement
else:
measurement = state.attributes.get('unit_of_measurement')
if measurement in (None, ''):
if default_measurement:
measurement = default_measurement
else:
measurement = state.entity_id
point = {
'measurement': measurement,
'tags': {
'domain': state.domain,
'entity_id': state.object_id,
},
'time': event.time_fired,
'fields': {
_state_key: _state,
}
} }
}
for key, value in state.attributes.items(): for key, value in state.attributes.items():
if key != 'unit_of_measurement': if key != 'unit_of_measurement':
# If the key is already in fields # If the key is already in fields
if key in point['fields']: if key in point['fields']:
key = key + "_" key = key + '_'
# Prevent column data errors in influxDB. # Prevent column data errors in influxDB.
# For each value we try to cast it as float # For each value we try to cast it as float
# But if we can not do it we store the value # But if we can not do it we store the value
# as string add "_str" postfix to the field key # as string add "_str" postfix to the field key
try: try:
point['fields'][key] = float(value) point['fields'][key] = float(value)
except (ValueError, TypeError): except (ValueError, TypeError):
new_key = "{}_str".format(key) new_key = '{}_str'.format(key)
point['fields'][new_key] = str(value) point['fields'][new_key] = str(value)
entities[state.entity_id] += 1 entities[state.entity_id] += 1
point['tags'].update(tags) point['tags'].update(tags)
points.append(point) points.append(point)
session.expunge(event) session.expunge(event)
if len(points) >= step:
if points:
if not simulate: if not simulate:
print("Write {} points to the database".format(len(points)))
client.write_points(points) client.write_points(points)
count += len(points) count += len(points)
points = [] # This prevents the progress bar from going over 100% when
# the last step happens
print_progress((step_start + len(
points)), total_events, prefix_format.format(
step_start, total_events))
else:
print_progress(
(step_start + step), total_events, prefix_format.format(
step_start, total_events))
if points: points = []
if not simulate: step_start += step
print("Write {} points to the database".format(len(points)))
client.write_points(points)
count += len(points)
print("\nStatistics:") print("\nStatistics:")
print("\n".join(["{:6}: {}".format(v, k) for k, v print("\n".join(["{:6}: {}".format(v, k) for k, v
in sorted(entities.items(), key=lambda x: x[1])])) in sorted(entities.items(), key=lambda x: x[1])]))
print("\nImport finished {} points written".format(count)) print("\nInvalid Points: {}".format(len(invalid_points)))
print("\nImport finished: {} points written".format(count))
return 0 return 0
# Based on code at
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
def print_progress(iteration: int, total: int, prefix: str='', suffix: str='',
decimals: int=2, bar_length: int=68) -> None:
"""Print progress bar.
Call in a loop to create terminal progress bar
@params:
iteration - Required : current iteration (Int)
total - Required : total iterations (Int)
prefix - Optional : prefix string (Str)
suffix - Optional : suffix string (Str)
decimals - Optional : number of decimals in percent complete (Int)
barLength - Optional : character length of bar (Int)
"""
filled_length = int(round(bar_length * iteration / float(total)))
percents = round(100.00 * (iteration / float(total)), decimals)
line = '#' * filled_length + '-' * (bar_length - filled_length)
sys.stdout.write('%s [%s] %s%s %s\r' % (prefix, line,
percents, '%', suffix))
sys.stdout.flush()
if iteration == total:
print('\n')