"""
Component that performs TensorFlow classification on images.

For a quick start, pick a pre-trained COCO model from:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/image_processing.tensorflow/
"""
import logging
import sys
import os

import voluptuous as vol

from homeassistant.components.image_processing import (
    CONF_CONFIDENCE, CONF_ENTITY_ID, CONF_NAME, CONF_SOURCE, PLATFORM_SCHEMA,
    ImageProcessingEntity)
from homeassistant.core import split_entity_id
from homeassistant.helpers import template
import homeassistant.helpers.config_validation as cv

REQUIREMENTS = ['numpy==1.15.3', 'pillow==5.2.0',
                'protobuf==3.6.1', 'tensorflow==1.11.0']

_LOGGER = logging.getLogger(__name__)

ATTR_MATCHES = 'matches'
ATTR_SUMMARY = 'summary'
ATTR_TOTAL_MATCHES = 'total_matches'

CONF_FILE_OUT = 'file_out'
CONF_MODEL = 'model'
CONF_GRAPH = 'graph'
CONF_LABELS = 'labels'
CONF_MODEL_DIR = 'model_dir'
CONF_CATEGORIES = 'categories'
CONF_CATEGORY = 'category'
CONF_AREA = 'area'
CONF_TOP = 'top'
CONF_LEFT = 'left'
CONF_BOTTOM = 'bottom'
CONF_RIGHT = 'right'

# A detection area expressed as relative (0..1) image coordinates;
# defaults cover the whole frame.
AREA_SCHEMA = vol.Schema({
    vol.Optional(CONF_TOP, default=0): cv.small_float,
    vol.Optional(CONF_LEFT, default=0): cv.small_float,
    vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
    vol.Optional(CONF_RIGHT, default=1): cv.small_float
})

CATEGORY_SCHEMA = vol.Schema({
    vol.Required(CONF_CATEGORY): cv.string,
    vol.Optional(CONF_AREA): AREA_SCHEMA
})

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
    vol.Optional(CONF_FILE_OUT, default=[]):
        vol.All(cv.ensure_list, [cv.template]),
    vol.Required(CONF_MODEL): vol.Schema({
        vol.Required(CONF_GRAPH): cv.isfile,
        vol.Optional(CONF_LABELS): cv.isfile,
        vol.Optional(CONF_MODEL_DIR): cv.isdir,
        vol.Optional(CONF_AREA): AREA_SCHEMA,
        vol.Optional(CONF_CATEGORIES, default=[]):
            vol.All(cv.ensure_list, [vol.Any(
                cv.string,
                CATEGORY_SCHEMA
            )])
    })
})


def draw_box(draw, box, img_width,
             img_height, text='', color=(255, 255, 0)):
    """Draw a bounding box on an image.

    ``box`` is (ymin, xmin, ymax, xmax) in relative (0..1) coordinates,
    matching the TensorFlow Object Detection API convention; it is scaled
    here to pixel coordinates of the target image.
    """
    ymin, xmin, ymax, xmax = box
    (left, right, top, bottom) = (xmin * img_width, xmax * img_width,
                                  ymin * img_height, ymax * img_height)
    draw.line([(left, top), (left, bottom), (right, bottom),
               (right, top), (left, top)], width=5, fill=color)
    if text:
        # abs() keeps the label on-canvas when the box touches the top edge.
        draw.text((left, abs(top-15)), text, fill=color)


def setup_platform(hass, config, add_entities, discovery_info=None):
    """Set up the TensorFlow image processing platform."""
    model_config = config.get(CONF_MODEL)
    model_dir = model_config.get(CONF_MODEL_DIR) \
        or hass.config.path('tensorflow')
    labels = model_config.get(CONF_LABELS) \
        or hass.config.path('tensorflow', 'object_detection',
                            'data', 'mscoco_label_map.pbtxt')

    # Make sure locations exist
    if not os.path.isdir(model_dir) or not os.path.exists(labels):
        _LOGGER.error("Unable to locate tensorflow models or label map.")
        return

    # append custom model path to sys.path so `object_detection` resolves
    sys.path.append(model_dir)

    try:
        # Verify that the TensorFlow Object Detection API is pre-installed
        # pylint: disable=unused-import,unused-variable
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
        import tensorflow as tf  # noqa
        from object_detection.utils import label_map_util  # noqa
    except ImportError:
        # pylint: disable=line-too-long
        _LOGGER.error(
            "No TensorFlow Object Detection library found! Install or compile "
            "for your system following instructions here: "
            "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md")  # noqa
        return

    try:
        # Display warning that PIL will be used if no OpenCV is found.
        # pylint: disable=unused-import,unused-variable
        import cv2  # noqa
    except ImportError:
        _LOGGER.warning("No OpenCV library found. "
                        "TensorFlow will process image with "
                        "PIL at reduced resolution.")

    # setup tensorflow graph, session, and label map to pass to processor
    # pylint: disable=no-member
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(model_config.get(CONF_GRAPH), 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    session = tf.Session(graph=detection_graph)
    label_map = label_map_util.load_labelmap(labels)
    # 90 classes matches the pre-trained COCO label map referenced above.
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=90, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    entities = []

    for camera in config[CONF_SOURCE]:
        entities.append(TensorFlowImageProcessor(
            hass, camera[CONF_ENTITY_ID], camera.get(CONF_NAME),
            session, detection_graph, category_index, config))

    add_entities(entities)


class TensorFlowImageProcessor(ImageProcessingEntity):
    """Representation of a TensorFlow image processor."""

    def __init__(self, hass, camera_entity, name, session, detection_graph,
                 category_index, config):
        """Initialize the TensorFlow entity."""
        model_config = config.get(CONF_MODEL)
        self.hass = hass
        self._camera_entity = camera_entity
        if name:
            self._name = name
        else:
            self._name = "TensorFlow {0}".format(
                split_entity_id(camera_entity)[1])
        self._session = session
        self._graph = detection_graph
        self._category_index = category_index
        self._min_confidence = config.get(CONF_CONFIDENCE)
        self._file_out = config.get(CONF_FILE_OUT)

        # handle categories and specific detection areas
        categories = model_config.get(CONF_CATEGORIES)
        self._include_categories = []
        # Per-category area as [top, left, bottom, right] relative coords.
        self._category_areas = {}
        for category in categories:
            if isinstance(category, dict):
                category_name = category.get(CONF_CATEGORY)
                category_area = category.get(CONF_AREA)
                self._include_categories.append(category_name)
                self._category_areas[category_name] = [0, 0, 1, 1]
                if category_area:
                    self._category_areas[category_name] = [
                        category_area.get(CONF_TOP),
                        category_area.get(CONF_LEFT),
                        category_area.get(CONF_BOTTOM),
                        category_area.get(CONF_RIGHT)
                    ]
            else:
                # Bare string entry: include the category with no area limit.
                self._include_categories.append(category)
                self._category_areas[category] = [0, 0, 1, 1]

        # Handle global detection area
        self._area = [0, 0, 1, 1]
        area_config = model_config.get(CONF_AREA)
        if area_config:
            self._area = [
                area_config.get(CONF_TOP),
                area_config.get(CONF_LEFT),
                area_config.get(CONF_BOTTOM),
                area_config.get(CONF_RIGHT)
            ]

        template.attach(hass, self._file_out)

        self._matches = {}
        self._total_matches = 0
        self._last_image = None

    @property
    def camera_entity(self):
        """Return camera entity id from process pictures."""
        return self._camera_entity

    @property
    def name(self):
        """Return the name of the image processor."""
        return self._name

    @property
    def state(self):
        """Return the state of the entity."""
        return self._total_matches

    @property
    def device_state_attributes(self):
        """Return device specific state attributes."""
        return {
            ATTR_MATCHES: self._matches,
            ATTR_SUMMARY: {category: len(values)
                           for category, values in self._matches.items()},
            ATTR_TOTAL_MATCHES: self._total_matches
        }

    def _save_image(self, image, matches, paths):
        """Render detection boxes onto the image and save to each path."""
        from PIL import Image, ImageDraw
        import io
        img = Image.open(io.BytesIO(bytearray(image))).convert('RGB')
        img_width, img_height = img.size
        draw = ImageDraw.Draw(img)

        # Draw custom global region/area
        if self._area != [0, 0, 1, 1]:
            draw_box(draw, self._area,
                     img_width, img_height,
                     "Detection Area", (0, 255, 255))

        for category, values in matches.items():
            # Draw custom category regions/areas
            if self._category_areas[category] != [0, 0, 1, 1]:
                label = "{} Detection Area".format(category.capitalize())
                draw_box(draw, self._category_areas[category], img_width,
                         img_height, label, (0, 255, 0))

            # Draw detected objects
            for instance in values:
                label = "{0} {1:.1f}%".format(category, instance['score'])
                draw_box(draw, instance['box'],
                         img_width, img_height,
                         label, (255, 255, 0))

        for path in paths:
            _LOGGER.info("Saving results image to %s", path)
            img.save(path)

    def process_image(self, image):
        """Process the image."""
        import numpy as np

        try:
            # Prefer OpenCV: full-resolution decode straight to a numpy array.
            import cv2  # pylint: disable=import-error
            img = cv2.imdecode(
                np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
            inp = img[:, :, [2, 1, 0]]  # BGR->RGB
            inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
        except ImportError:
            # PIL fallback: downscale to keep per-pixel conversion affordable.
            from PIL import Image
            import io
            img = Image.open(io.BytesIO(bytearray(image))).convert('RGB')
            img.thumbnail((460, 460), Image.ANTIALIAS)
            img_width, img_height = img.size
            inp = np.array(img.getdata()).reshape(
                (img_height, img_width, 3)).astype(np.uint8)
            inp_expanded = np.expand_dims(inp, axis=0)

        image_tensor = self._graph.get_tensor_by_name('image_tensor:0')
        boxes = self._graph.get_tensor_by_name('detection_boxes:0')
        scores = self._graph.get_tensor_by_name('detection_scores:0')
        classes = self._graph.get_tensor_by_name('detection_classes:0')
        boxes, scores, classes = self._session.run(
            [boxes, scores, classes],
            feed_dict={image_tensor: inp_expanded})
        boxes, scores, classes = map(np.squeeze, [boxes, scores, classes])
        classes = classes.astype(int)

        matches = {}
        total_matches = 0
        for box, score, obj_class in zip(boxes, scores, classes):
            score = score * 100
            # Distinct local name: re-binding ``boxes`` here would shadow
            # the detection-results array while it is still being iterated.
            box_list = box.tolist()

            # Exclude matches below min confidence value
            if score < self._min_confidence:
                continue

            # Exclude matches outside global area definition
            # (box is [top, left, bottom, right] in relative coordinates)
            if (box_list[0] < self._area[0] or box_list[1] < self._area[1]
                    or box_list[2] > self._area[2]
                    or box_list[3] > self._area[3]):
                continue

            category = self._category_index[obj_class]['name']

            # Exclude unlisted categories
            if (self._include_categories
                    and category not in self._include_categories):
                continue

            # Exclude matches outside category specific area definition
            if (self._category_areas
                    and (box_list[0] < self._category_areas[category][0]
                         or box_list[1] < self._category_areas[category][1]
                         or box_list[2] > self._category_areas[category][2]
                         or box_list[3] > self._category_areas[category][3])):
                continue

            # If we got here, we should include it
            matches.setdefault(category, []).append({
                'score': float(score),
                'box': box_list
            })
            total_matches += 1

        # Save Images
        if total_matches and self._file_out:
            paths = []
            for path_template in self._file_out:
                if isinstance(path_template, template.Template):
                    paths.append(path_template.render(
                        camera_entity=self._camera_entity))
                else:
                    paths.append(path_template)
            self._save_image(image, matches, paths)

        self._matches = matches
        self._total_matches = total_matches