Improve type hints in tensorflow (#145433)

* Improve type hints in tensorflow

* Use ANTIALIAS again

* Use Image.Resampling.LANCZOS
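The last two bullets track a Pillow API change rather than typing: Image.ANTIALIAS was deprecated in Pillow 9.1 and removed in Pillow 10, and Image.Resampling.LANCZOS is the same filter under the new enum. A minimal sketch of the call the diff below settles on, assuming Pillow 9.1 or newer; the file name is only a placeholder:

    from PIL import Image

    # Downscale a copy in place to fit within 460x460, as process_image() does below.
    # Image.Resampling.LANCZOS (Pillow >= 9.1) replaces the removed Image.ANTIALIAS
    # alias; both select the Lanczos filter.
    with Image.open("snapshot.jpg") as img:  # "snapshot.jpg" is a placeholder path
        img = img.convert("RGB")
        img.thumbnail((460, 460), Image.Resampling.LANCZOS)
        print(img.size)  # at most 460x460, aspect ratio preserved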
epenet 2025-05-22 11:23:33 +02:00 committed by GitHub
parent 6e74b56649
commit 40267760fd

@@ -7,6 +7,7 @@ import logging
 import os
 import sys
 import time
+from typing import Any
 import numpy as np
 from PIL import Image, ImageDraw, UnidentifiedImageError
@@ -54,6 +55,8 @@ CONF_MODEL_DIR = "model_dir"
 CONF_RIGHT = "right"
 CONF_TOP = "top"
+_DEFAULT_AREA = (0.0, 0.0, 1.0, 1.0)
 AREA_SCHEMA = vol.Schema(
     {
         vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
@@ -189,19 +192,21 @@ def setup_platform(
     hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)
-    category_index = label_map_util.create_category_index_from_labelmap(
-        labels, use_display_name=True
+    category_index: dict[int, dict[str, Any]] = (
+        label_map_util.create_category_index_from_labelmap(
+            labels, use_display_name=True
+        )
     )
+    source: list[dict[str, str]] = config[CONF_SOURCE]
     add_entities(
         TensorFlowImageProcessor(
-            hass,
             camera[CONF_ENTITY_ID],
             camera.get(CONF_NAME),
             category_index,
             config,
         )
-        for camera in config[CONF_SOURCE]
+        for camera in source
     )
@@ -210,78 +215,66 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
     def __init__(
         self,
-        hass,
-        camera_entity,
-        name,
-        category_index,
-        config,
-    ):
+        camera_entity: str,
+        name: str | None,
+        category_index: dict[int, dict[str, Any]],
+        config: ConfigType,
+    ) -> None:
         """Initialize the TensorFlow entity."""
-        model_config = config.get(CONF_MODEL)
-        self.hass = hass
-        self._camera_entity = camera_entity
+        model_config: dict[str, Any] = config[CONF_MODEL]
+        self._attr_camera_entity = camera_entity
         if name:
-            self._name = name
+            self._attr_name = name
         else:
-            self._name = f"TensorFlow {split_entity_id(camera_entity)[1]}"
+            self._attr_name = f"TensorFlow {split_entity_id(camera_entity)[1]}"
         self._category_index = category_index
         self._min_confidence = config.get(CONF_CONFIDENCE)
         self._file_out = config.get(CONF_FILE_OUT)
         # handle categories and specific detection areas
         self._label_id_offset = model_config.get(CONF_LABEL_OFFSET)
-        categories = model_config.get(CONF_CATEGORIES)
+        categories: list[str | dict[str, Any]] = model_config[CONF_CATEGORIES]
         self._include_categories = []
-        self._category_areas = {}
+        self._category_areas: dict[str, tuple[float, float, float, float]] = {}
         for category in categories:
             if isinstance(category, dict):
-                category_name = category.get(CONF_CATEGORY)
+                category_name: str = category[CONF_CATEGORY]
                 category_area = category.get(CONF_AREA)
                 self._include_categories.append(category_name)
-                self._category_areas[category_name] = [0, 0, 1, 1]
+                self._category_areas[category_name] = _DEFAULT_AREA
                 if category_area:
-                    self._category_areas[category_name] = [
-                        category_area.get(CONF_TOP),
-                        category_area.get(CONF_LEFT),
-                        category_area.get(CONF_BOTTOM),
-                        category_area.get(CONF_RIGHT),
-                    ]
+                    self._category_areas[category_name] = (
+                        category_area[CONF_TOP],
+                        category_area[CONF_LEFT],
+                        category_area[CONF_BOTTOM],
+                        category_area[CONF_RIGHT],
+                    )
             else:
                 self._include_categories.append(category)
-                self._category_areas[category] = [0, 0, 1, 1]
+                self._category_areas[category] = _DEFAULT_AREA
         # Handle global detection area
-        self._area = [0, 0, 1, 1]
+        self._area = _DEFAULT_AREA
         if area_config := model_config.get(CONF_AREA):
-            self._area = [
-                area_config.get(CONF_TOP),
-                area_config.get(CONF_LEFT),
-                area_config.get(CONF_BOTTOM),
-                area_config.get(CONF_RIGHT),
-            ]
+            self._area = (
+                area_config[CONF_TOP],
+                area_config[CONF_LEFT],
+                area_config[CONF_BOTTOM],
+                area_config[CONF_RIGHT],
+            )
-        self._matches = {}
+        self._matches: dict[str, list[dict[str, Any]]] = {}
         self._total_matches = 0
         self._last_image = None
-        self._process_time = 0
+        self._process_time = 0.0
-    @property
-    def camera_entity(self):
-        """Return camera entity id from process pictures."""
-        return self._camera_entity
-    @property
-    def name(self):
-        """Return the name of the image processor."""
-        return self._name
     @property
-    def state(self):
+    def state(self) -> int:
         """Return the state of the entity."""
         return self._total_matches
     @property
-    def extra_state_attributes(self):
+    def extra_state_attributes(self) -> dict[str, Any]:
         """Return device specific state attributes."""
         return {
             ATTR_MATCHES: self._matches,
@@ -292,25 +285,25 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
             ATTR_PROCESS_TIME: self._process_time,
         }
-    def _save_image(self, image, matches, paths):
+    def _save_image(
+        self, image: bytes, matches: dict[str, list[dict[str, Any]]], paths: list[str]
+    ) -> None:
         img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
         img_width, img_height = img.size
         draw = ImageDraw.Draw(img)
         # Draw custom global region/area
-        if self._area != [0, 0, 1, 1]:
+        if self._area != _DEFAULT_AREA:
             draw_box(
                 draw, self._area, img_width, img_height, "Detection Area", (0, 255, 255)
             )
         for category, values in matches.items():
             # Draw custom category regions/areas
-            if category in self._category_areas and self._category_areas[category] != [
-                0,
-                0,
-                1,
-                1,
-            ]:
+            if (
+                category in self._category_areas
+                and self._category_areas[category] != _DEFAULT_AREA
+            ):
                 label = f"{category.capitalize()} Detection Area"
                 draw_box(
                     draw,
@@ -333,7 +326,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
             os.makedirs(os.path.dirname(path), exist_ok=True)
             img.save(path)
-    def process_image(self, image):
+    def process_image(self, image: bytes) -> None:
         """Process the image."""
         if not (model := self.hass.data[DOMAIN][CONF_MODEL]):
             _LOGGER.debug("Model not yet ready")
@@ -352,7 +345,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
         except UnidentifiedImageError:
            _LOGGER.warning("Unable to process image, bad data")
            return
-        img.thumbnail((460, 460), Image.ANTIALIAS)
+        img.thumbnail((460, 460), Image.Resampling.LANCZOS)
         img_width, img_height = img.size
         inp = (
             np.array(img.getdata())
@@ -371,7 +364,7 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
             detections["detection_classes"][0].numpy() + self._label_id_offset
         ).astype(int)
-        matches = {}
+        matches: dict[str, list[dict[str, Any]]] = {}
         total_matches = 0
         for box, score, obj_class in zip(boxes, scores, classes, strict=False):
             score = score * 100
@@ -416,9 +409,7 @@
             paths = []
             for path_template in self._file_out:
                 if isinstance(path_template, template.Template):
-                    paths.append(
-                        path_template.render(camera_entity=self._camera_entity)
-                    )
+                    paths.append(path_template.render(camera_entity=self.camera_entity))
                 else:
                     paths.append(path_template)
             self._save_image(image, matches, paths)