mirror of
https://github.com/home-assistant/core.git
synced 2025-07-29 16:17:20 +00:00
Refactor imports for tensorflow (#27617)
* Refactoring imports for tensorflow * Removing whitespace spaces on blank line 110 * Moving tensorflow to try/except block * Fixed black formatting * Refactoring try/except to if/else
This commit is contained in:
parent
09de6d5889
commit
5a83a92390
@ -2,8 +2,22 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import io
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
except ImportError:
|
||||||
|
cv2 = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Verify that the TensorFlow Object Detection API is pre-installed
|
||||||
|
import tensorflow as tf # noqa
|
||||||
|
from object_detection.utils import label_map_util # noqa
|
||||||
|
except ImportError:
|
||||||
|
label_map_util = None
|
||||||
|
|
||||||
from homeassistant.components.image_processing import (
|
from homeassistant.components.image_processing import (
|
||||||
CONF_CONFIDENCE,
|
CONF_CONFIDENCE,
|
||||||
@ -84,14 +98,8 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||||||
# append custom model path to sys.path
|
# append custom model path to sys.path
|
||||||
sys.path.append(model_dir)
|
sys.path.append(model_dir)
|
||||||
|
|
||||||
try:
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
||||||
# Verify that the TensorFlow Object Detection API is pre-installed
|
if label_map_util is None:
|
||||||
# pylint: disable=unused-import,unused-variable
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
|
||||||
import tensorflow as tf # noqa
|
|
||||||
from object_detection.utils import label_map_util # noqa
|
|
||||||
except ImportError:
|
|
||||||
# pylint: disable=line-too-long
|
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"No TensorFlow Object Detection library found! Install or compile "
|
"No TensorFlow Object Detection library found! Install or compile "
|
||||||
"for your system following instructions here: "
|
"for your system following instructions here: "
|
||||||
@ -99,11 +107,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
|
|||||||
) # noqa
|
) # noqa
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
if cv2 is None:
|
||||||
# Display warning that PIL will be used if no OpenCV is found.
|
|
||||||
# pylint: disable=unused-import,unused-variable
|
|
||||||
import cv2 # noqa
|
|
||||||
except ImportError:
|
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"No OpenCV library found. TensorFlow will process image with "
|
"No OpenCV library found. TensorFlow will process image with "
|
||||||
"PIL at reduced resolution"
|
"PIL at reduced resolution"
|
||||||
@ -236,9 +240,6 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _save_image(self, image, matches, paths):
|
def _save_image(self, image, matches, paths):
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
import io
|
|
||||||
|
|
||||||
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
||||||
img_width, img_height = img.size
|
img_width, img_height = img.size
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
@ -280,18 +281,8 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
|
|
||||||
def process_image(self, image):
|
def process_image(self, image):
|
||||||
"""Process the image."""
|
"""Process the image."""
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
try:
|
|
||||||
import cv2 # pylint: disable=import-error
|
|
||||||
|
|
||||||
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
|
|
||||||
inp = img[:, :, [2, 1, 0]] # BGR->RGB
|
|
||||||
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
|
|
||||||
except ImportError:
|
|
||||||
from PIL import Image
|
|
||||||
import io
|
|
||||||
|
|
||||||
|
if cv2 is None:
|
||||||
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
||||||
img.thumbnail((460, 460), Image.ANTIALIAS)
|
img.thumbnail((460, 460), Image.ANTIALIAS)
|
||||||
img_width, img_height = img.size
|
img_width, img_height = img.size
|
||||||
@ -301,6 +292,10 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
.astype(np.uint8)
|
.astype(np.uint8)
|
||||||
)
|
)
|
||||||
inp_expanded = np.expand_dims(inp, axis=0)
|
inp_expanded = np.expand_dims(inp, axis=0)
|
||||||
|
else:
|
||||||
|
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
|
||||||
|
inp = img[:, :, [2, 1, 0]] # BGR->RGB
|
||||||
|
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
|
||||||
|
|
||||||
image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
|
image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
|
||||||
boxes = self._graph.get_tensor_by_name("detection_boxes:0")
|
boxes = self._graph.get_tensor_by_name("detection_boxes:0")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user