mirror of
https://github.com/home-assistant/core.git
synced 2025-08-20 10:50:06 +00:00
.devcontainer
.github
.vscode
homeassistant
auth
backports
brands
components
3_day_blinds
abode
accuweather
acer_projector
acmeda
acomax
actiontec
adax
adguard
ads
advantage_air
aemet
aep_ohio
aep_texas
aftership
agent_dvr
air_quality
airgradient
airly
airnow
airq
airthings
airthings_ble
airtouch4
airtouch5
airvisual
airvisual_pro
airzone
airzone_cloud
aladdin_connect
alarm_control_panel
alarmdecoder
alert
alexa
alpha_vantage
amazon_polly
amberelectric
ambient_network
ambient_station
amcrest
amp_motorization
ampio
analytics
analytics_insights
android_ip_webcam
androidtv
androidtv_remote
anel_pwrctrl
anova
anthemav
anwb_energie
aosmith
apache_kafka
apcupsd
api
appalachianpower
apple_tv
application_credentials
apprise
aprilaire
aprs
apsystems
aqualogic
aquostv
aranet
arcam_fmj
arest
arris_tg2492lg
aruba
arve
arwn
aseko_pool_live
assist_pipeline
asterisk_cdr
asterisk_mbox
asuswrt
atag
aten_pe
atlanticcityelectric
atome
august
august_ble
aurora
aurora_abb_powerone
aussie_broadband
auth
automation
avea
avion
awair
aws
axis
azure_data_explorer
azure_devops
azure_event_hub
azure_service_bus
backup
baf
baidu
balboa
bang_olufsen
bayesian
bbox
beewi_smartclim
bge
binary_sensor
bitcoin
bizkaibus
blackbird
blebox
blink
blinksticklight
bliss_automation
bloc_blinds
blockchain
bloomsky
blue_current
bluemaestro
blueprint
bluesound
bluetooth
bluetooth_adapters
bluetooth_le_tracker
bluetooth_tracker
bmw_connected_drive
bond
bosch_shc
brandt
braviatv
brel_home
bring
broadlink
brother
brottsplatskartan
browser
brunt
bsblan
bswitch
bt_home_hub_5
bt_smarthub
bthome
bticino
bubendorff
buienradar
button
caldav
calendar
camera
canary
cast
ccm15
cert_expiry
channels
circuit
cisco_ios
cisco_mobility_express
cisco_webex_teams
citybikes
clementine
clickatell
clicksend
clicksend_tts
climate
cloud
cloudflare
cmus
co2signal
coautilities
coinbase
color_extractor
comed
comed_hourly_pricing
comelit
comfoconnect
command_line
compensation
concord232
coned
config
configurator
control4
conversation
coolmaster
counter
cover
cozytouch
cppm_tracker
cpuspeed
cribl
crownstone
cups
currencylayer
dacia
daikin
danfoss_air
datadog
date
datetime
ddwrt
debugpy
deconz
decora
decora_wifi
default_config
delijn
delmarva
deluge
demo
denon
denonavr
derivative
devialet
device_automation
device_sun_light_trigger
device_tracker
devolo_home_control
devolo_home_network
dexcom
dhcp
diagnostics
dialogflow
diaz
digital_loggers
digital_ocean
directv
discogs
discord
discovergy
dlib_face_detect
dlib_face_identify
dlink
dlna_dmr
dlna_dms
dnsip
dominos
doods
doorbird
dooya
dormakaba_dkey
dovado
downloader
dremel_3d_printer
drop_connect
dsmr
dsmr_reader
dte_energy_bridge
dublin_bus_transport
duckdns
dunehd
duotecno
duquesne_light
dwd_weather_warnings
dweet
dynalite
eafm
eastron
easyenergy
ebox
ebusd
ecoal_boiler
ecobee
ecoforest
econet
ecovacs
ecowitt
eddystone_temperature
edimax
edl21
efergy
egardia
eight_sleep
electrasmart
electric_kiwi
elgato
eliqonline
elkm1
elmax
elv
elvia
emby
emoncms
emoncms_history
emonitor
emulated_hue
emulated_kasa
emulated_roku
energenie_power_sockets
energie_vanons
energy
energyzero
enigma2
enmax
enocean
enphase_envoy
entur_public_transport
environment_canada
envisalink
ephember
epic_games_store
epion
epson
eq3btsmart
escea
esera_onewire
esphome
etherscan
eufy
eufylife_ble
event
evergy
everlights
evil_genius_labs
evohome
ezviz
faa_delays
facebook
fail2ban
familyhub
fan
fastdotcom
feedreader
ffmpeg
ffmpeg_motion
ffmpeg_noise
fibaro
fido
file
file_upload
filesize
filter
fints
fire_tv
fireservicerota
firmata
fitbit
fivem
fixer
fjaraskupan
fleetgo
flexit
flexit_bacnet
flexom
flic
flick_electric
flipr
flo
flock
flume
flux
flux_led
folder
folder_watcher
foobot
forecast_solar
forked_daapd
fortios
foscam
foursquare
free_mobile
freebox
freedns
freedompro
fritz
fritzbox
fritzbox_callmonitor
fronius
frontend
frontier_silicon
fujitsu_anywair
fully_kiosk
futurenow
fyta
garadget
garages_amsterdam
gardena_bluetooth
gaviota
gc100
gdacs
generic
generic_hygrostat
generic_thermostat
geniushub
geo_json_events
geo_location
geo_rss_events
geocaching
geofency
geonetnz_quakes
geonetnz_volcano
gios
github
gitlab_ci
gitter
glances
goalzero
gogogate2
goodwe
google
google_assistant
google_assistant_sdk
google_cloud
google_domains
google_generative_ai_conversation
google_mail
google_maps
google_pubsub
google_sheets
google_tasks
google_translate
google_travel_time
google_wifi
govee_ble
govee_light_local
gpsd
gpslogger
graphite
gree
greeneye_monitor
greenwave
group
growatt_server
gstreamer
gtfs
guardian
habitica
hardkernel
hardware
harman_kardon_avr
harmony
hassio
havana_shade
haveibeenpwned
hddtemp
hdmi_cec
heatmiser
heiwa
heos
here_travel_time
hexaom
hi_kumo
hikvision
hikvisioncam
hisense_aehw4a1
history
history_stats
hitron_coda
hive
hko
hlk_sw16
holiday
home_connect
home_plus_control
homeassistant
homeassistant_alerts
homeassistant_green
homeassistant_hardware
homeassistant_sky_connect
homeassistant_yellow
homekit
homekit_controller
homematic
homematicip_cloud
homewizard
homeworks
honeywell
horizon
hp_ilo
html5
http
huawei_lte
hue
huisbaasje
humidifier
hunterdouglas_powerview
hurrican_shutters_wholesale
husqvarna_automower
huum
hvv_departures
hydrawise
hyperion
ialarm
iammeter
iaqualink
ibeacon
icloud
idasen_desk
idteck_prox
ifttt
iglo
ign_sismologia
ihc
image
image_processing
image_upload
imap
imgw_pib
improv_ble
incomfort
indianamichiganpower
influxdb
inkbird
input_boolean
input_button
input_datetime
input_number
input_select
input_text
inspired_shades
insteon
integration
intellifire
intent
intent_script
intesishome
ios
iotawatt
iperf3
ipma
ipp
iqvia
irish_rail_transport
isal
islamic_prayer_times
ismartwindow
iss
isy994
itach
itunes
izone
jellyfin
jewish_calendar
joaoapps_join
juicenet
justnimbus
jvc_projector
kaiterra
kaleidescape
kankun
keba
keenetic_ndms2
kef
kegtron
kentuckypower
keyboard
keyboard_remote
keymitt_ble
kira
kitchen_sink
kiwi
kmtronic
knx
kodi
konnected
kostal_plenticore
kraken
krispol
kulersky
kwb
lacrosse
lacrosse_view
lamarzocco
lametric
landisgyr_heat_meter
lannouncer
lastfm
launch_library
laundrify
lawn_mower
lcn
ld2410_ble
leaone
led_ble
legrand
lg_netcast
lg_soundbar
lidarr
life360
lifx
lifx_cloud
light
lightwave
limitlessled
linear_garage_door
linksys_smart
linode
linux_battery
lirc
litejet
litterrobot
livisi
llamalab_automate
local_calendar
local_file
local_ip
local_todo
locative
lock
logbook
logentries
logger
logi_circle
london_air
london_underground
lookin
loqed
lovelace
luci
luftdaten
lupusec
lutron
lutron_caseta
luxaflex
lw12wifi
lyric
madeco
mailbox
mailgun
manual
manual_mqtt
map
marantz
martec
marytts
mastodon
matrix
matter
maxcube
mazda
meater
medcom_ble
media_extractor
media_player
media_source
mediaroom
melcloud
melissa
melnor
meraki
message_bird
met
met_eireann
meteo_france
meteoalarm
meteoclimatic
metoffice
mfi
microbees
microsoft
microsoft_face
microsoft_face_detect
microsoft_face_identify
mijndomein_energie
mikrotik
mill
min_max
minecraft_server
minio
mjpeg
moat
mobile_app
mochad
modbus
modem_callerid
modern_forms
moehlenhoff_alpha2
mold_indicator
monessen
monoprice
monzo
moon
mopeka
motion_blinds
motionblinds_ble
motioneye
motionmount
mpd
mqtt
mqtt_eventstream
mqtt_json
mqtt_room
mqtt_statestream
msteams
mullvad
mutesync
mvglive
my
mycroft
myq
mysensors
mystrom
mythicbeastsdns
myuplink
nad
nam
namecheapdns
nanoleaf
neato
nederlandse_spoorwegen
ness_alarm
nest
netatmo
netdata
netgear
netgear_lte
netio
network
neurio_energy
nexia
nexity
nextbus
nextcloud
nextdns
nfandroidtv
nibe_heatpump
nightscout
niko_home_control
nilu
nina
nissan_leaf
nmap_tracker
nmbs
no_ip
noaa_tides
nobo_hub
norway_air
notify
notify_events
notion
nsw_fuel_station
nsw_rural_fire_service_feed
nuheat
nuki
numato
number
nut
nutrichef
nws
nx584
nzbget
oasa_telematics
obihai
octoprint
oem
ohmconnect
ollama
ombi
omnilogic
onboarding
oncue
ondilo_ico
onewire
onkyo
onvif
open_meteo
openai_conversation
openalpr_cloud
openerz
openevse
openexchangerates
opengarage
openhardwaremonitor
openhome
opensensemap
opensky
opentherm_gw
openuv
openweathermap
opnsense
opower
opple
oralb
oru
oru_opower
orvibo
osoenergy
osramlightify
otbr
otp
ourgroceries
overkiz
ovo_energy
owntracks
p1_monitor
panasonic_bluray
panasonic_viera
pandora
panel_custom
panel_iframe
pcs_lighting
peco
peco_opower
pegel_online
pencom
pepco
permobil
persistent_notification
person
pge
philips_js
pi_hole
picnic
picotts
pilight
ping
pioneer
piper
pjlink
plaato
plant
plex
plugwise
plum_lightpad
pocketcasts
point
poolsense
portlandgeneral
powerwall
private_ble_device
profiler
progettihwsw
proliphix
prometheus
prosegur
prowl
proximity
proxmoxve
proxy
prusalink
ps4
pse
psoklahoma
pulseaudio_loopback
pure_energie
purpleair
push
pushbullet
pushover
pushsafer
pvoutput
pvpc_hourly_pricing
pyload
python_script
qbittorrent
qingping
qld_bushfire
qnap
qnap_qsw
qrcode
quadrafire
quantum_gateway
qvr_pro
qwikswitch
rabbitair
rachio
radarr
radio_browser
radiotherm
rainbird
raincloud
rainforest_eagle
rainforest_raven
rainmachine
random
rapt_ble
raspberry_pi
raspyrfm
raven_rock_mfg
rdw
recollect_waste
recorder
recovery_mode
recswitch
reddit
refoss
rejseplanen
remember_the_milk
remote
remote_rpi_gpio
renault
renson
reolink
repairs
repetier
rest
rest_command
rexel
rflink
rfxtrx
rhasspy
ridwell
ring
ripple
risco
rituals_perfume_genie
rmvtransport
roborock
rocketchat
roku
romy
roomba
roon
route53
rova
rpi_camera
rpi_power
rss_feed_template
rtorrent
rtsp_to_webrtc
ruckus_unleashed
russound_rio
russound_rnet
ruuvi_gateway
ruuvitag_ble
rympro
sabnzbd
saj
samsam
samsungtv
sanix
satel_integra
scene
schedule
schlage
schluter
scl
scrape
screenaway
screenlogic
script
scsgate
search
season
select
sendgrid
sense
sensibo
sensirion_ble
sensor
sensorblue
sensorpro
sensorpush
sentry
senz
serial
serial_pm
sesame
seven_segments
seventeentrack
sfr_box
sharkiq
shell_command
shelly
shodan
shopping_list
sia
sigfox
sighthound
signal_messenger
simplepush
simplisafe
simply_automated
simu
simulated
sinch
siren
sisyphus
sky_hub
skybeacon
skybell
slack
sleepiq
slide
slimproto
sma
smappee
smart_blinds
smart_home
smart_meter_texas
smarther
smartthings
smarttub
smarty
smhi
sms
smtp
smud
snapcast
snips
snmp
snooz
solaredge
solaredge_local
solarlog
solax
soma
somfy
somfy_mylink
sonarr
songpal
sonos
sony_projector
soundtouch
spaceapi
spc
speedtestdotnet
spider
splunk
spotify
sql
squeezebox
srp_energy
ssdp
starline
starlingbank
starlink
startca
statistics
statsd
steam_online
steamist
stiebel_eltron
stookalert
stookwijzer
stream
streamlabswater
stt
subaru
suez_water
sun
sunweg
supervisord
supla
surepetcare
swepco
swiss_hydrological_data
swiss_public_transport
swisscom
switch
switch_as_x
switchbee
switchbot
switchbot_cloud
switcher_kis
switchmate
symfonisk
syncthing
syncthru
synology_chat
synology_dsm
synology_srm
syslog
system_bridge
system_health
system_log
systemmonitor
tado
tag
tailscale
tailwind
tami4
tank_utility
tankerkoenig
tapsaff
tasmota
tautulli
tcp
technove
ted5000
tedee
telegram
telegram_bot
tellduslive
tellstick
telnet
temper
template
tensorflow
__init__.py
image_processing.py
manifest.json
tesla_wall_connector
teslemetry
tessie
text
tfiac
thermobeacon
thermoplus
thermopro
thermoworks_smoke
thethingsnetwork
thingspeak
thinkingcleaner
thomson
thread
threshold
tibber
tikteck
tile
tilt_ble
time
time_date
timer
tmb
tod
todo
todoist
tolo
tomato
tomorrowio
toon
torque
totalconnect
touchline
tplink
tplink_lte
tplink_omada
tplink_tapo
traccar
traccar_server
trace
tractive
tradfri
trafikverket_camera
trafikverket_ferry
trafikverket_train
trafikverket_weatherstation
transmission
transport_nsw
travisci
trend
tts
tuya
twentemilieu
twilio
twilio_call
twilio_sms
twinkly
twitch
twitter
ubiwizz
ubus
ue_smart_radio
uk_transport
ukraine_alarm
ultraloq
unifi
unifi_direct
unifiled
unifiprotect
universal
upb
upc_connect
upcloud
update
upnp
uprise_smart_shades
uptime
uptimerobot
usb
usgs_earthquakes_feed
utility_meter
uvc
v2c
vacuum
vallox
valve
vasttrafik
velbus
velux
venstar
vera
verisure
vermont_castings
versasense
version
vesync
viaggiatreno
vicare
vilfo
vivotek
vizio
vlc
vlc_telnet
vodafone_station
voicerss
voip
volkszaehler
volumio
volvooncall
vulcan
vultr
w800rf32
wake_on_lan
wake_word
wallbox
waqi
water_heater
waterfurnace
watson_iot
watson_tts
watttime
waze_travel_time
weather
weatherflow
weatherflow_cloud
weatherkit
webhook
webmin
webostv
websocket_api
wemo
whirlpool
whisper
whois
wiffi
wilight
wirelesstag
withings
wiz
wled
wolflink
workday
worldclock
worldtidesinfo
worxlandroid
ws66i
wsdot
wyoming
x10
xbox
xeoma
xiaomi
xiaomi_aqara
xiaomi_ble
xiaomi_miio
xiaomi_tv
xmpp
xs1
yale_home
yale_smart_alarm
yalexs_ble
yamaha
yamaha_musiccast
yandex_transport
yandextts
yardian
yeelight
yeelightsunflower
yi
yolink
youless
youtube
zabbix
zamg
zengge
zeroconf
zerproc
zestimate
zeversolar
zha
zhong_hong
ziggo_mediabox_xl
zodiac
zondergas
zone
zoneminder
zwave_js
zwave_me
__init__.py
generated
helpers
scripts
util
__init__.py
__main__.py
block_async_io.py
bootstrap.py
config.py
config_entries.py
const.py
core.py
data_entry_flow.py
exceptions.py
loader.py
package_constraints.txt
py.typed
requirements.py
runner.py
setup.py
strings.json
machine
pylint
rootfs
script
tests
.core_files.yaml
.coveragerc
.dockerignore
.git-blame-ignore-revs
.gitattributes
.gitignore
.hadolint.yaml
.pre-commit-config.yaml
.prettierignore
.strict-typing
.yamllint
CLA.md
CODEOWNERS
CODE_OF_CONDUCT.md
CONTRIBUTING.md
Dockerfile
Dockerfile.dev
LICENSE.md
MANIFEST.in
README.rst
build.yaml
codecov.yml
mypy.ini
pyproject.toml
requirements.txt
requirements_all.txt
requirements_test.txt
requirements_test_all.txt
requirements_test_pre_commit.txt
432 lines
15 KiB
Python
432 lines
15 KiB
Python
"""Support for performing TensorFlow classification on images."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw, UnidentifiedImageError
|
|
import tensorflow as tf
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components.image_processing import (
|
|
CONF_CONFIDENCE,
|
|
PLATFORM_SCHEMA,
|
|
ImageProcessingEntity,
|
|
)
|
|
from homeassistant.const import (
|
|
CONF_ENTITY_ID,
|
|
CONF_MODEL,
|
|
CONF_NAME,
|
|
CONF_SOURCE,
|
|
EVENT_HOMEASSISTANT_START,
|
|
)
|
|
from homeassistant.core import HomeAssistant, split_entity_id
|
|
from homeassistant.helpers import template
|
|
import homeassistant.helpers.config_validation as cv
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
from homeassistant.util.pil import draw_box
|
|
|
|
# NOTE(review): presumably lowers the verbosity of TensorFlow's native
# logging. It is assigned after `import tensorflow` above - confirm that it
# still takes effect at this point.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Key under which the loaded model is stored in hass.data.
DOMAIN = "tensorflow"
_LOGGER = logging.getLogger(__name__)

# Names of the extra state attributes exposed by the entity.
ATTR_MATCHES = "matches"
ATTR_SUMMARY = "summary"
ATTR_TOTAL_MATCHES = "total_matches"
ATTR_PROCESS_TIME = "process_time"

# Configuration keys accepted by this platform.
CONF_AREA = "area"
CONF_BOTTOM = "bottom"
CONF_CATEGORIES = "categories"
CONF_CATEGORY = "category"
CONF_FILE_OUT = "file_out"
CONF_GRAPH = "graph"
CONF_LABELS = "labels"
CONF_LABEL_OFFSET = "label_offset"
CONF_LEFT = "left"
CONF_MODEL_DIR = "model_dir"
CONF_RIGHT = "right"
CONF_TOP = "top"

# A detection area. The defaults (top=0, left=0, bottom=1, right=1) describe
# the area that the entity treats as "no custom area configured".
# Values are validated by cv.small_float (presumably a 0..1 fraction - confirm).
AREA_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_BOTTOM, default=1): cv.small_float,
        vol.Optional(CONF_LEFT, default=0): cv.small_float,
        vol.Optional(CONF_RIGHT, default=1): cv.small_float,
        vol.Optional(CONF_TOP, default=0): cv.small_float,
    }
)

# A category given in mapping form: a required name plus an optional area.
CATEGORY_SCHEMA = vol.Schema(
    {vol.Required(CONF_CATEGORY): cv.string, vol.Optional(CONF_AREA): AREA_SCHEMA}
)

# Extends the image_processing platform schema with the output file templates
# and the model section (graph directory, optional global area, categories as
# plain strings or CATEGORY_SCHEMA mappings, label map file, label offset and
# model directory).
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
    {
        vol.Optional(CONF_FILE_OUT, default=[]): vol.All(cv.ensure_list, [cv.template]),
        vol.Required(CONF_MODEL): vol.Schema(
            {
                vol.Required(CONF_GRAPH): cv.isdir,
                vol.Optional(CONF_AREA): AREA_SCHEMA,
                vol.Optional(CONF_CATEGORIES, default=[]): vol.All(
                    cv.ensure_list, [vol.Any(cv.string, CATEGORY_SCHEMA)]
                ),
                vol.Optional(CONF_LABELS): cv.isfile,
                vol.Optional(CONF_LABEL_OFFSET, default=1): int,
                vol.Optional(CONF_MODEL_DIR): cv.isdir,
            }
        ),
    }
)
|
|
|
|
|
|
def get_model_detection_function(model):
    """Wrap the detection pipeline of *model* in a single tf.function.

    The returned callable chains the model's preprocess, predict and
    postprocess steps and returns the postprocess result.
    """

    @tf.function
    def detect_fn(image):
        """Run preprocess -> predict -> postprocess on the given input."""
        preprocessed, true_shapes = model.preprocess(image)
        raw_predictions = model.predict(preprocessed, true_shapes)
        detections = model.postprocess(raw_predictions, true_shapes)
        return detections

    return detect_fn
|
|
|
|
|
|
def setup_platform(
    hass: HomeAssistant,
    config: ConfigType,
    add_entities: AddEntitiesCallback,
    discovery_info: DiscoveryInfoType | None = None,
) -> None:
    """Set up the TensorFlow image processing platform.

    Resolves the model directory, label map, checkpoint directory and
    pipeline config from the platform configuration, verifies that they
    exist and that the Object Detection API can be imported, schedules the
    model load for Home Assistant start, and adds one entity per configured
    camera source. Logs an error and returns without adding entities when a
    required location or library is missing.
    """
    model_config = config[CONF_MODEL]

    # Fall back to locations inside the Home Assistant config directory.
    model_dir = model_config.get(CONF_MODEL_DIR) or hass.config.path("tensorflow")
    labels = model_config.get(CONF_LABELS) or hass.config.path(
        "tensorflow", "object_detection", "data", "mscoco_label_map.pbtxt"
    )
    graph_dir = model_config[CONF_GRAPH]
    checkpoint = os.path.join(graph_dir, "checkpoint")
    pipeline_config = os.path.join(graph_dir, "pipeline.config")

    # Every required location has to be present before going any further.
    locations_present = (
        os.path.isdir(model_dir)
        and os.path.isdir(checkpoint)
        and os.path.exists(pipeline_config)
        and os.path.exists(labels)
    )
    if not locations_present:
        _LOGGER.error("Unable to locate tensorflow model or label map")
        return

    # The Object Detection API is imported from the custom model directory.
    sys.path.append(model_dir)

    try:
        # These imports must stay local: they resolve against model_dir,
        # which is only on sys.path from this point on. (The model_dir is
        # created during the manual setup process. See integration docs.)

        # pylint: disable=import-outside-toplevel
        from object_detection.builders import model_builder
        from object_detection.utils import config_util, label_map_util
    except ImportError:
        _LOGGER.error(
            "No TensorFlow Object Detection library found! Install or compile "
            "for your system following instructions here: "
            "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md#installation"
        )
        return

    try:
        # Only probing for availability; warn when falling back to PIL.
        import cv2  # noqa: F401 pylint: disable=import-outside-toplevel
    except ImportError:
        _LOGGER.warning(
            "No OpenCV library found. TensorFlow will process image with "
            "PIL at reduced resolution"
        )

    # Entities skip processing while the stored model is still None.
    hass.data[DOMAIN] = {CONF_MODEL: None}

    def tensorflow_hass_start(_event):
        """Build, restore and warm up the model once Home Assistant starts."""
        started = time.perf_counter()

        # Build the detection model described by the pipeline config.
        configs = config_util.get_configs_from_pipeline_file(pipeline_config)
        detection_model = model_builder.build(
            model_config=configs["model"], is_training=False
        )

        # Restore the weights from the checkpoint directory.
        restorer = tf.compat.v2.train.Checkpoint(model=detection_model)
        restorer.restore(os.path.join(checkpoint, "ckpt-0")).expect_partial()

        _LOGGER.debug(
            "Model checkpoint restore took %d seconds", time.perf_counter() - started
        )

        detect = get_model_detection_function(detection_model)

        # Run one inference on an all-zero frame so the model cache is
        # populated before the first real image arrives. The frame is
        # converted to a float32 tensor and given a leading batch axis.
        blank_frame = np.zeros((2160, 3840, 3), dtype=np.uint8)
        warmup_batch = tf.convert_to_tensor(blank_frame, dtype=tf.float32)[
            tf.newaxis, ...
        ]
        detect(warmup_batch)

        _LOGGER.debug("Model load took %d seconds", time.perf_counter() - started)
        hass.data[DOMAIN][CONF_MODEL] = detect

    hass.bus.listen_once(EVENT_HOMEASSISTANT_START, tensorflow_hass_start)

    category_index = label_map_util.create_category_index_from_labelmap(
        labels, use_display_name=True
    )

    processors = (
        TensorFlowImageProcessor(
            hass,
            source[CONF_ENTITY_ID],
            source.get(CONF_NAME),
            category_index,
            config,
        )
        for source in config[CONF_SOURCE]
    )
    add_entities(processors)
|
|
|
|
|
|
class TensorFlowImageProcessor(ImageProcessingEntity):
    """Representation of a TensorFlow image processor.

    Runs the detection callable stored in ``hass.data[DOMAIN][CONF_MODEL]``
    on camera images, filters the detections by confidence, area and
    category, and exposes the result through the entity state and attributes.
    """

    def __init__(
        self,
        hass,
        camera_entity,
        name,
        category_index,
        config,
    ):
        """Initialize the TensorFlow entity.

        hass: the Home Assistant instance.
        camera_entity: entity id of the camera whose images are processed.
        name: entity name; when falsy a name is derived from camera_entity.
        category_index: mapping from class id to a dict with a "name" key.
        config: the platform configuration for this integration.
        """
        model_config = config.get(CONF_MODEL)
        self.hass = hass
        self._camera_entity = camera_entity
        if name:
            self._name = name
        else:
            self._name = f"TensorFlow {split_entity_id(camera_entity)[1]}"
        self._category_index = category_index
        self._min_confidence = config.get(CONF_CONFIDENCE)
        self._file_out = config.get(CONF_FILE_OUT)

        # Handle categories and category specific detection areas.
        # Areas are stored as [top, left, bottom, right]; [0, 0, 1, 1] is the
        # value used when no custom area is configured.
        self._label_id_offset = model_config.get(CONF_LABEL_OFFSET)
        self._include_categories = []
        self._category_areas = {}
        for category in model_config.get(CONF_CATEGORIES):
            if isinstance(category, dict):
                category_name = category.get(CONF_CATEGORY)
                category_area = category.get(CONF_AREA)
                self._include_categories.append(category_name)
                self._category_areas[category_name] = [0, 0, 1, 1]
                if category_area:
                    self._category_areas[category_name] = [
                        category_area.get(CONF_TOP),
                        category_area.get(CONF_LEFT),
                        category_area.get(CONF_BOTTOM),
                        category_area.get(CONF_RIGHT),
                    ]
            else:
                self._include_categories.append(category)
                self._category_areas[category] = [0, 0, 1, 1]

        # Handle global detection area
        self._area = [0, 0, 1, 1]
        if area_config := model_config.get(CONF_AREA):
            self._area = [
                area_config.get(CONF_TOP),
                area_config.get(CONF_LEFT),
                area_config.get(CONF_BOTTOM),
                area_config.get(CONF_RIGHT),
            ]

        template.attach(hass, self._file_out)

        self._matches = {}
        self._total_matches = 0
        self._last_image = None
        self._process_time = 0

    @property
    def camera_entity(self):
        """Return camera entity id from process pictures."""
        return self._camera_entity

    @property
    def name(self):
        """Return the name of the image processor."""
        return self._name

    @property
    def state(self):
        """Return the state of the entity (the total number of matches)."""
        return self._total_matches

    @property
    def extra_state_attributes(self):
        """Return device specific state attributes.

        Contains the matches per category, a per-category match count, the
        total match count and the duration of the last processing run.
        """
        return {
            ATTR_MATCHES: self._matches,
            ATTR_SUMMARY: {
                category: len(values) for category, values in self._matches.items()
            },
            ATTR_TOTAL_MATCHES: self._total_matches,
            ATTR_PROCESS_TIME: self._process_time,
        }

    def _save_image(self, image, matches, paths):
        """Draw detection areas and matches on the image and save it to paths.

        image: the encoded image bytes that were processed.
        matches: mapping of category to a list of {"score", "box"} dicts.
        paths: output file paths; missing parent directories are created.
        """
        img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
        img_width, img_height = img.size
        draw = ImageDraw.Draw(img)

        # Draw custom global region/area
        if self._area != [0, 0, 1, 1]:
            draw_box(
                draw, self._area, img_width, img_height, "Detection Area", (0, 255, 255)
            )

        for category, values in matches.items():
            # Draw custom category regions/areas
            category_area = self._category_areas.get(category)
            if category_area is not None and category_area != [0, 0, 1, 1]:
                label = f"{category.capitalize()} Detection Area"
                draw_box(
                    draw,
                    category_area,
                    img_width,
                    img_height,
                    label,
                    (0, 255, 0),
                )

            # Draw detected objects
            for instance in values:
                label = f"{category} {instance['score']:.1f}%"
                draw_box(
                    draw, instance["box"], img_width, img_height, label, (255, 255, 0)
                )

        for path in paths:
            _LOGGER.info("Saving results image to %s", path)
            # A bare file name has no directory part; os.makedirs("") would
            # raise, so only create the parent when there is one.
            if directory := os.path.dirname(path):
                os.makedirs(directory, exist_ok=True)
            img.save(path)

    def process_image(self, image):
        """Process the image.

        Decodes the encoded image bytes (with OpenCV when importable,
        otherwise with PIL on a thumbnail of at most 460x460), runs the
        model, filters the detections and updates the entity's matches,
        total and process time. Returns early without touching the state
        when the model is not loaded yet or the image cannot be decoded.
        """
        if not (model := self.hass.data[DOMAIN][CONF_MODEL]):
            _LOGGER.debug("Model not yet ready")
            return

        start = time.perf_counter()
        try:
            import cv2  # pylint: disable=import-outside-toplevel

            img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
            if img is None:
                # imdecode signals undecodable data by returning None
                # instead of raising; mirror the PIL fallback's handling.
                _LOGGER.warning("Unable to process image, bad data")
                return
            inp = img[:, :, [2, 1, 0]]  # BGR->RGB
            inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
        except ImportError:
            try:
                img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
            except UnidentifiedImageError:
                _LOGGER.warning("Unable to process image, bad data")
                return
            # Image.ANTIALIAS was removed in Pillow 10; it was an alias of
            # the LANCZOS filter, so the resampling result is unchanged.
            img.thumbnail((460, 460), Image.Resampling.LANCZOS)
            img_width, img_height = img.size
            inp = (
                np.array(img.getdata())
                .reshape((img_height, img_width, 3))
                .astype(np.uint8)
            )
            inp_expanded = np.expand_dims(inp, axis=0)

        # The input needs to be a tensor, convert it using `tf.convert_to_tensor`.
        input_tensor = tf.convert_to_tensor(inp_expanded, dtype=tf.float32)

        detections = model(input_tensor)
        boxes = detections["detection_boxes"][0].numpy()
        scores = detections["detection_scores"][0].numpy()
        classes = (
            detections["detection_classes"][0].numpy() + self._label_id_offset
        ).astype(int)

        matches = {}
        total_matches = 0
        for box, score, obj_class in zip(boxes, scores, classes, strict=False):
            score = score * 100
            # Compared index-by-index against the [top, left, bottom, right]
            # area lists below.
            box_coords = box.tolist()

            # Exclude matches below min confidence value
            if score < self._min_confidence:
                continue

            # Exclude matches outside global area definition
            if (
                box_coords[0] < self._area[0]
                or box_coords[1] < self._area[1]
                or box_coords[2] > self._area[2]
                or box_coords[3] > self._area[3]
            ):
                continue

            category = self._category_index[obj_class]["name"]

            # Exclude unlisted categories
            if self._include_categories and category not in self._include_categories:
                continue

            # Exclude matches outside category specific area definition.
            # Every included category has an entry in _category_areas, so
            # the lookup is safe once the include filter above has passed.
            if self._category_areas:
                category_area = self._category_areas[category]
                if (
                    box_coords[0] < category_area[0]
                    or box_coords[1] < category_area[1]
                    or box_coords[2] > category_area[2]
                    or box_coords[3] > category_area[3]
                ):
                    continue

            # If we got here, we should include it
            matches.setdefault(category, []).append(
                {"score": float(score), "box": box_coords}
            )
            total_matches += 1

        # Save Images
        if total_matches and self._file_out:
            paths = []
            for path_template in self._file_out:
                if isinstance(path_template, template.Template):
                    paths.append(
                        path_template.render(camera_entity=self._camera_entity)
                    )
                else:
                    paths.append(path_template)
            self._save_image(image, matches, paths)

        self._matches = matches
        self._total_matches = total_matches
        self._process_time = time.perf_counter() - start
|