mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Improve type hints in tensorflow (#145433)
* Improve type hints in tensorflow * Use ANTIALIAS again * Use Image.Resampling.LANCZOS
This commit is contained in:
parent
6e74b56649
commit
40267760fd
@ -7,6 +7,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageDraw, UnidentifiedImageError
|
from PIL import Image, ImageDraw, UnidentifiedImageError
|
||||||
@ -54,6 +55,8 @@ CONF_MODEL_DIR = "model_dir"
|
|||||||
CONF_RIGHT = "right"
|
CONF_RIGHT = "right"
|
||||||
CONF_TOP = "top"
|
CONF_TOP = "top"
|
||||||
|
|
||||||
|
_DEFAULT_AREA = (0.0, 0.0, 1.0, 1.0)
|
||||||
|
|
||||||
AREA_SCHEMA = vol.Schema(
|
AREA_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
|
vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
|
||||||
@ -189,19 +192,21 @@ def setup_platform(
|
|||||||
|
|
||||||
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
|
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
|
||||||
|
|
||||||
category_index = label_map_util.create_category_index_from_labelmap(
|
category_index: dict[int, dict[str, Any]] = (
|
||||||
labels, use_display_name=True
|
label_map_util.create_category_index_from_labelmap(
|
||||||
|
labels, use_display_name=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
source: list[dict[str, str]] = config[CONF_SOURCE]
|
||||||
add_entities(
|
add_entities(
|
||||||
TensorFlowImageProcessor(
|
TensorFlowImageProcessor(
|
||||||
hass,
|
|
||||||
camera[CONF_ENTITY_ID],
|
camera[CONF_ENTITY_ID],
|
||||||
camera.get(CONF_NAME),
|
camera.get(CONF_NAME),
|
||||||
category_index,
|
category_index,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
for camera in config[CONF_SOURCE]
|
for camera in source
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -210,78 +215,66 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass,
|
camera_entity: str,
|
||||||
camera_entity,
|
name: str | None,
|
||||||
name,
|
category_index: dict[int, dict[str, Any]],
|
||||||
category_index,
|
config: ConfigType,
|
||||||
config,
|
) -> None:
|
||||||
):
|
|
||||||
"""Initialize the TensorFlow entity."""
|
"""Initialize the TensorFlow entity."""
|
||||||
model_config = config.get(CONF_MODEL)
|
model_config: dict[str, Any] = config[CONF_MODEL]
|
||||||
self.hass = hass
|
self._attr_camera_entity = camera_entity
|
||||||
self._camera_entity = camera_entity
|
|
||||||
if name:
|
if name:
|
||||||
self._name = name
|
self._attr_name = name
|
||||||
else:
|
else:
|
||||||
self._name = f"TensorFlow {split_entity_id(camera_entity)[1]}"
|
self._attr_name = f"TensorFlow {split_entity_id(camera_entity)[1]}"
|
||||||
self._category_index = category_index
|
self._category_index = category_index
|
||||||
self._min_confidence = config.get(CONF_CONFIDENCE)
|
self._min_confidence = config.get(CONF_CONFIDENCE)
|
||||||
self._file_out = config.get(CONF_FILE_OUT)
|
self._file_out = config.get(CONF_FILE_OUT)
|
||||||
|
|
||||||
# handle categories and specific detection areas
|
# handle categories and specific detection areas
|
||||||
self._label_id_offset = model_config.get(CONF_LABEL_OFFSET)
|
self._label_id_offset = model_config.get(CONF_LABEL_OFFSET)
|
||||||
categories = model_config.get(CONF_CATEGORIES)
|
categories: list[str | dict[str, Any]] = model_config[CONF_CATEGORIES]
|
||||||
self._include_categories = []
|
self._include_categories = []
|
||||||
self._category_areas = {}
|
self._category_areas: dict[str, tuple[float, float, float, float]] = {}
|
||||||
for category in categories:
|
for category in categories:
|
||||||
if isinstance(category, dict):
|
if isinstance(category, dict):
|
||||||
category_name = category.get(CONF_CATEGORY)
|
category_name: str = category[CONF_CATEGORY]
|
||||||
category_area = category.get(CONF_AREA)
|
category_area = category.get(CONF_AREA)
|
||||||
self._include_categories.append(category_name)
|
self._include_categories.append(category_name)
|
||||||
self._category_areas[category_name] = [0, 0, 1, 1]
|
self._category_areas[category_name] = _DEFAULT_AREA
|
||||||
if category_area:
|
if category_area:
|
||||||
self._category_areas[category_name] = [
|
self._category_areas[category_name] = (
|
||||||
category_area.get(CONF_TOP),
|
category_area[CONF_TOP],
|
||||||
category_area.get(CONF_LEFT),
|
category_area[CONF_LEFT],
|
||||||
category_area.get(CONF_BOTTOM),
|
category_area[CONF_BOTTOM],
|
||||||
category_area.get(CONF_RIGHT),
|
category_area[CONF_RIGHT],
|
||||||
]
|
)
|
||||||
else:
|
else:
|
||||||
self._include_categories.append(category)
|
self._include_categories.append(category)
|
||||||
self._category_areas[category] = [0, 0, 1, 1]
|
self._category_areas[category] = _DEFAULT_AREA
|
||||||
|
|
||||||
# Handle global detection area
|
# Handle global detection area
|
||||||
self._area = [0, 0, 1, 1]
|
self._area = _DEFAULT_AREA
|
||||||
if area_config := model_config.get(CONF_AREA):
|
if area_config := model_config.get(CONF_AREA):
|
||||||
self._area = [
|
self._area = (
|
||||||
area_config.get(CONF_TOP),
|
area_config[CONF_TOP],
|
||||||
area_config.get(CONF_LEFT),
|
area_config[CONF_LEFT],
|
||||||
area_config.get(CONF_BOTTOM),
|
area_config[CONF_BOTTOM],
|
||||||
area_config.get(CONF_RIGHT),
|
area_config[CONF_RIGHT],
|
||||||
]
|
)
|
||||||
|
|
||||||
self._matches = {}
|
self._matches: dict[str, list[dict[str, Any]]] = {}
|
||||||
self._total_matches = 0
|
self._total_matches = 0
|
||||||
self._last_image = None
|
self._last_image = None
|
||||||
self._process_time = 0
|
self._process_time = 0.0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_entity(self):
|
def state(self) -> int:
|
||||||
"""Return camera entity id from process pictures."""
|
|
||||||
return self._camera_entity
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
"""Return the name of the image processor."""
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self):
|
|
||||||
"""Return the state of the entity."""
|
"""Return the state of the entity."""
|
||||||
return self._total_matches
|
return self._total_matches
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extra_state_attributes(self):
|
def extra_state_attributes(self) -> dict[str, Any]:
|
||||||
"""Return device specific state attributes."""
|
"""Return device specific state attributes."""
|
||||||
return {
|
return {
|
||||||
ATTR_MATCHES: self._matches,
|
ATTR_MATCHES: self._matches,
|
||||||
@ -292,25 +285,25 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
ATTR_PROCESS_TIME: self._process_time,
|
ATTR_PROCESS_TIME: self._process_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _save_image(self, image, matches, paths):
|
def _save_image(
|
||||||
|
self, image: bytes, matches: dict[str, list[dict[str, Any]]], paths: list[str]
|
||||||
|
) -> None:
|
||||||
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
|
||||||
img_width, img_height = img.size
|
img_width, img_height = img.size
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
# Draw custom global region/area
|
# Draw custom global region/area
|
||||||
if self._area != [0, 0, 1, 1]:
|
if self._area != _DEFAULT_AREA:
|
||||||
draw_box(
|
draw_box(
|
||||||
draw, self._area, img_width, img_height, "Detection Area", (0, 255, 255)
|
draw, self._area, img_width, img_height, "Detection Area", (0, 255, 255)
|
||||||
)
|
)
|
||||||
|
|
||||||
for category, values in matches.items():
|
for category, values in matches.items():
|
||||||
# Draw custom category regions/areas
|
# Draw custom category regions/areas
|
||||||
if category in self._category_areas and self._category_areas[category] != [
|
if (
|
||||||
0,
|
category in self._category_areas
|
||||||
0,
|
and self._category_areas[category] != _DEFAULT_AREA
|
||||||
1,
|
):
|
||||||
1,
|
|
||||||
]:
|
|
||||||
label = f"{category.capitalize()} Detection Area"
|
label = f"{category.capitalize()} Detection Area"
|
||||||
draw_box(
|
draw_box(
|
||||||
draw,
|
draw,
|
||||||
@ -333,7 +326,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
img.save(path)
|
img.save(path)
|
||||||
|
|
||||||
def process_image(self, image):
|
def process_image(self, image: bytes) -> None:
|
||||||
"""Process the image."""
|
"""Process the image."""
|
||||||
if not (model := self.hass.data[DOMAIN][CONF_MODEL]):
|
if not (model := self.hass.data[DOMAIN][CONF_MODEL]):
|
||||||
_LOGGER.debug("Model not yet ready")
|
_LOGGER.debug("Model not yet ready")
|
||||||
@ -352,7 +345,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
except UnidentifiedImageError:
|
except UnidentifiedImageError:
|
||||||
_LOGGER.warning("Unable to process image, bad data")
|
_LOGGER.warning("Unable to process image, bad data")
|
||||||
return
|
return
|
||||||
img.thumbnail((460, 460), Image.ANTIALIAS)
|
img.thumbnail((460, 460), Image.Resampling.LANCZOS)
|
||||||
img_width, img_height = img.size
|
img_width, img_height = img.size
|
||||||
inp = (
|
inp = (
|
||||||
np.array(img.getdata())
|
np.array(img.getdata())
|
||||||
@ -371,7 +364,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
detections["detection_classes"][0].numpy() + self._label_id_offset
|
detections["detection_classes"][0].numpy() + self._label_id_offset
|
||||||
).astype(int)
|
).astype(int)
|
||||||
|
|
||||||
matches = {}
|
matches: dict[str, list[dict[str, Any]]] = {}
|
||||||
total_matches = 0
|
total_matches = 0
|
||||||
for box, score, obj_class in zip(boxes, scores, classes, strict=False):
|
for box, score, obj_class in zip(boxes, scores, classes, strict=False):
|
||||||
score = score * 100
|
score = score * 100
|
||||||
@ -416,9 +409,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
|
|||||||
paths = []
|
paths = []
|
||||||
for path_template in self._file_out:
|
for path_template in self._file_out:
|
||||||
if isinstance(path_template, template.Template):
|
if isinstance(path_template, template.Template):
|
||||||
paths.append(
|
paths.append(path_template.render(camera_entity=self.camera_entity))
|
||||||
path_template.render(camera_entity=self._camera_entity)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
paths.append(path_template)
|
paths.append(path_template)
|
||||||
self._save_image(image, matches, paths)
|
self._save_image(image, matches, paths)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user