Refactor imports for tensorflow (#27617)

* Refactoring imports for tensorflow

* Removing whitespace spaces on blank line 110

* Moving tensorflow to try/except block

* Fixed black formatting

* Refactoring try/except to if/else
This commit is contained in:
Steven D. Lander 2019-10-14 11:44:30 -04:00 committed by Paulus Schoutsen
parent 09de6d5889
commit 5a83a92390

View File

@ -2,8 +2,22 @@
import logging import logging
import os import os
import sys import sys
import io
import voluptuous as vol import voluptuous as vol
from PIL import Image, ImageDraw
import numpy as np
try:
import cv2
except ImportError:
cv2 = None
try:
# Verify that the TensorFlow Object Detection API is pre-installed
import tensorflow as tf # noqa
from object_detection.utils import label_map_util # noqa
except ImportError:
label_map_util = None
from homeassistant.components.image_processing import ( from homeassistant.components.image_processing import (
CONF_CONFIDENCE, CONF_CONFIDENCE,
@ -84,14 +98,8 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
# append custom model path to sys.path # append custom model path to sys.path
sys.path.append(model_dir) sys.path.append(model_dir)
try: os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Verify that the TensorFlow Object Detection API is pre-installed if label_map_util is None:
# pylint: disable=unused-import,unused-variable
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf # noqa
from object_detection.utils import label_map_util # noqa
except ImportError:
# pylint: disable=line-too-long
_LOGGER.error( _LOGGER.error(
"No TensorFlow Object Detection library found! Install or compile " "No TensorFlow Object Detection library found! Install or compile "
"for your system following instructions here: " "for your system following instructions here: "
@ -99,11 +107,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
) # noqa ) # noqa
return return
try: if cv2 is None:
# Display warning that PIL will be used if no OpenCV is found.
# pylint: disable=unused-import,unused-variable
import cv2 # noqa
except ImportError:
_LOGGER.warning( _LOGGER.warning(
"No OpenCV library found. TensorFlow will process image with " "No OpenCV library found. TensorFlow will process image with "
"PIL at reduced resolution" "PIL at reduced resolution"
@ -236,9 +240,6 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
} }
def _save_image(self, image, matches, paths): def _save_image(self, image, matches, paths):
from PIL import Image, ImageDraw
import io
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB") img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
img_width, img_height = img.size img_width, img_height = img.size
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
@ -280,18 +281,8 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
def process_image(self, image): def process_image(self, image):
"""Process the image.""" """Process the image."""
import numpy as np
try:
import cv2 # pylint: disable=import-error
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
except ImportError:
from PIL import Image
import io
if cv2 is None:
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB") img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
img.thumbnail((460, 460), Image.ANTIALIAS) img.thumbnail((460, 460), Image.ANTIALIAS)
img_width, img_height = img.size img_width, img_height = img.size
@ -301,6 +292,10 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
.astype(np.uint8) .astype(np.uint8)
) )
inp_expanded = np.expand_dims(inp, axis=0) inp_expanded = np.expand_dims(inp, axis=0)
else:
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
image_tensor = self._graph.get_tensor_by_name("image_tensor:0") image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
boxes = self._graph.get_tensor_by_name("detection_boxes:0") boxes = self._graph.get_tensor_by_name("detection_boxes:0")