mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 07:37:34 +00:00
Add facebox auth (#15439)
* Adds auth * Update facebox.py * Update test_facebox.py * Update facebox.py * Update facebox.py * Update facebox.py * Update facebox.py * Remove TIMEOUT * Update test_facebox.py * fix lint * Update facebox.py * Update test_facebox.py * Update facebox.py * Adds check_box_health * Adds test auth * Update test_facebox.py * Update test_facebox.py * Update test_facebox.py * Update test_facebox.py * Ups coverage * Update test_facebox.py * Update facebox.py * Update test_facebox.py * Update facebox.py * Update test_facebox.py * Update facebox.py * Update facebox.py * Update facebox.py
This commit is contained in:
parent
47fa928425
commit
61721478f3
@ -17,25 +17,29 @@ import homeassistant.helpers.config_validation as cv
|
|||||||
from homeassistant.components.image_processing import (
|
from homeassistant.components.image_processing import (
|
||||||
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
|
PLATFORM_SCHEMA, ImageProcessingFaceEntity, ATTR_CONFIDENCE, CONF_SOURCE,
|
||||||
CONF_ENTITY_ID, CONF_NAME, DOMAIN)
|
CONF_ENTITY_ID, CONF_NAME, DOMAIN)
|
||||||
from homeassistant.const import (CONF_IP_ADDRESS, CONF_PORT)
|
from homeassistant.const import (
|
||||||
|
CONF_IP_ADDRESS, CONF_PORT, CONF_PASSWORD, CONF_USERNAME,
|
||||||
|
HTTP_BAD_REQUEST, HTTP_OK, HTTP_UNAUTHORIZED)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
ATTR_BOUNDING_BOX = 'bounding_box'
|
ATTR_BOUNDING_BOX = 'bounding_box'
|
||||||
ATTR_CLASSIFIER = 'classifier'
|
ATTR_CLASSIFIER = 'classifier'
|
||||||
ATTR_IMAGE_ID = 'image_id'
|
ATTR_IMAGE_ID = 'image_id'
|
||||||
|
ATTR_ID = 'id'
|
||||||
ATTR_MATCHED = 'matched'
|
ATTR_MATCHED = 'matched'
|
||||||
|
FACEBOX_NAME = 'name'
|
||||||
CLASSIFIER = 'facebox'
|
CLASSIFIER = 'facebox'
|
||||||
DATA_FACEBOX = 'facebox_classifiers'
|
DATA_FACEBOX = 'facebox_classifiers'
|
||||||
EVENT_CLASSIFIER_TEACH = 'image_processing.teach_classifier'
|
|
||||||
FILE_PATH = 'file_path'
|
FILE_PATH = 'file_path'
|
||||||
SERVICE_TEACH_FACE = 'facebox_teach_face'
|
SERVICE_TEACH_FACE = 'facebox_teach_face'
|
||||||
TIMEOUT = 9
|
|
||||||
|
|
||||||
|
|
||||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
|
||||||
vol.Required(CONF_IP_ADDRESS): cv.string,
|
vol.Required(CONF_IP_ADDRESS): cv.string,
|
||||||
vol.Required(CONF_PORT): cv.port,
|
vol.Required(CONF_PORT): cv.port,
|
||||||
|
vol.Optional(CONF_USERNAME): cv.string,
|
||||||
|
vol.Optional(CONF_PASSWORD): cv.string,
|
||||||
})
|
})
|
||||||
|
|
||||||
SERVICE_TEACH_SCHEMA = vol.Schema({
|
SERVICE_TEACH_SCHEMA = vol.Schema({
|
||||||
@ -45,6 +49,26 @@ SERVICE_TEACH_SCHEMA = vol.Schema({
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def check_box_health(url, username, password):
|
||||||
|
"""Check the health of the classifier and return its id if healthy."""
|
||||||
|
kwargs = {}
|
||||||
|
if username:
|
||||||
|
kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
url,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
if response.status_code == HTTP_UNAUTHORIZED:
|
||||||
|
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
|
||||||
|
return None
|
||||||
|
if response.status_code == HTTP_OK:
|
||||||
|
return response.json()['hostname']
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def encode_image(image):
|
def encode_image(image):
|
||||||
"""base64 encode an image stream."""
|
"""base64 encode an image stream."""
|
||||||
base64_img = base64.b64encode(image).decode('ascii')
|
base64_img = base64.b64encode(image).decode('ascii')
|
||||||
@ -63,10 +87,10 @@ def parse_faces(api_faces):
|
|||||||
for entry in api_faces:
|
for entry in api_faces:
|
||||||
face = {}
|
face = {}
|
||||||
if entry['matched']: # This data is only in matched faces.
|
if entry['matched']: # This data is only in matched faces.
|
||||||
face[ATTR_NAME] = entry['name']
|
face[FACEBOX_NAME] = entry['name']
|
||||||
face[ATTR_IMAGE_ID] = entry['id']
|
face[ATTR_IMAGE_ID] = entry['id']
|
||||||
else: # Lets be explicit.
|
else: # Lets be explicit.
|
||||||
face[ATTR_NAME] = None
|
face[FACEBOX_NAME] = None
|
||||||
face[ATTR_IMAGE_ID] = None
|
face[ATTR_IMAGE_ID] = None
|
||||||
face[ATTR_CONFIDENCE] = round(100.0*entry['confidence'], 2)
|
face[ATTR_CONFIDENCE] = round(100.0*entry['confidence'], 2)
|
||||||
face[ATTR_MATCHED] = entry['matched']
|
face[ATTR_MATCHED] = entry['matched']
|
||||||
@ -75,17 +99,46 @@ def parse_faces(api_faces):
|
|||||||
return known_faces
|
return known_faces
|
||||||
|
|
||||||
|
|
||||||
def post_image(url, image):
|
def post_image(url, image, username, password):
|
||||||
"""Post an image to the classifier."""
|
"""Post an image to the classifier."""
|
||||||
|
kwargs = {}
|
||||||
|
if username:
|
||||||
|
kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
json={"base64": encode_image(image)},
|
json={"base64": encode_image(image)},
|
||||||
timeout=TIMEOUT
|
**kwargs
|
||||||
)
|
)
|
||||||
|
if response.status_code == HTTP_UNAUTHORIZED:
|
||||||
|
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
|
||||||
|
return None
|
||||||
return response
|
return response
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def teach_file(url, name, file_path, username, password):
|
||||||
|
"""Teach the classifier a name associated with a file."""
|
||||||
|
kwargs = {}
|
||||||
|
if username:
|
||||||
|
kwargs['auth'] = requests.auth.HTTPBasicAuth(username, password)
|
||||||
|
try:
|
||||||
|
with open(file_path, 'rb') as open_file:
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
data={FACEBOX_NAME: name, ATTR_ID: file_path},
|
||||||
|
files={'file': open_file},
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
if response.status_code == HTTP_UNAUTHORIZED:
|
||||||
|
_LOGGER.error("AuthenticationError on %s", CLASSIFIER)
|
||||||
|
elif response.status_code == HTTP_BAD_REQUEST:
|
||||||
|
_LOGGER.error("%s teaching of file %s failed with message:%s",
|
||||||
|
CLASSIFIER, file_path, response.text)
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
_LOGGER.error("ConnectionError: Is %s running?", CLASSIFIER)
|
||||||
|
|
||||||
|
|
||||||
def valid_file_path(file_path):
|
def valid_file_path(file_path):
|
||||||
@ -104,13 +157,20 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
|||||||
if DATA_FACEBOX not in hass.data:
|
if DATA_FACEBOX not in hass.data:
|
||||||
hass.data[DATA_FACEBOX] = []
|
hass.data[DATA_FACEBOX] = []
|
||||||
|
|
||||||
|
ip_address = config[CONF_IP_ADDRESS]
|
||||||
|
port = config[CONF_PORT]
|
||||||
|
username = config.get(CONF_USERNAME)
|
||||||
|
password = config.get(CONF_PASSWORD)
|
||||||
|
url_health = "http://{}:{}/healthz".format(ip_address, port)
|
||||||
|
hostname = check_box_health(url_health, username, password)
|
||||||
|
if hostname is None:
|
||||||
|
return
|
||||||
|
|
||||||
entities = []
|
entities = []
|
||||||
for camera in config[CONF_SOURCE]:
|
for camera in config[CONF_SOURCE]:
|
||||||
facebox = FaceClassifyEntity(
|
facebox = FaceClassifyEntity(
|
||||||
config[CONF_IP_ADDRESS],
|
ip_address, port, username, password, hostname,
|
||||||
config[CONF_PORT],
|
camera[CONF_ENTITY_ID], camera.get(CONF_NAME))
|
||||||
camera[CONF_ENTITY_ID],
|
|
||||||
camera.get(CONF_NAME))
|
|
||||||
entities.append(facebox)
|
entities.append(facebox)
|
||||||
hass.data[DATA_FACEBOX].append(facebox)
|
hass.data[DATA_FACEBOX].append(facebox)
|
||||||
add_devices(entities)
|
add_devices(entities)
|
||||||
@ -129,33 +189,37 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
|
|||||||
classifier.teach(name, file_path)
|
classifier.teach(name, file_path)
|
||||||
|
|
||||||
hass.services.register(
|
hass.services.register(
|
||||||
DOMAIN,
|
DOMAIN, SERVICE_TEACH_FACE, service_handle,
|
||||||
SERVICE_TEACH_FACE,
|
|
||||||
service_handle,
|
|
||||||
schema=SERVICE_TEACH_SCHEMA)
|
schema=SERVICE_TEACH_SCHEMA)
|
||||||
|
|
||||||
|
|
||||||
class FaceClassifyEntity(ImageProcessingFaceEntity):
|
class FaceClassifyEntity(ImageProcessingFaceEntity):
|
||||||
"""Perform a face classification."""
|
"""Perform a face classification."""
|
||||||
|
|
||||||
def __init__(self, ip, port, camera_entity, name=None):
|
def __init__(self, ip_address, port, username, password, hostname,
|
||||||
|
camera_entity, name=None):
|
||||||
"""Init with the API key and model id."""
|
"""Init with the API key and model id."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._url_check = "http://{}:{}/{}/check".format(ip, port, CLASSIFIER)
|
self._url_check = "http://{}:{}/{}/check".format(
|
||||||
self._url_teach = "http://{}:{}/{}/teach".format(ip, port, CLASSIFIER)
|
ip_address, port, CLASSIFIER)
|
||||||
|
self._url_teach = "http://{}:{}/{}/teach".format(
|
||||||
|
ip_address, port, CLASSIFIER)
|
||||||
|
self._username = username
|
||||||
|
self._password = password
|
||||||
|
self._hostname = hostname
|
||||||
self._camera = camera_entity
|
self._camera = camera_entity
|
||||||
if name:
|
if name:
|
||||||
self._name = name
|
self._name = name
|
||||||
else:
|
else:
|
||||||
camera_name = split_entity_id(camera_entity)[1]
|
camera_name = split_entity_id(camera_entity)[1]
|
||||||
self._name = "{} {}".format(
|
self._name = "{} {}".format(CLASSIFIER, camera_name)
|
||||||
CLASSIFIER, camera_name)
|
|
||||||
self._matched = {}
|
self._matched = {}
|
||||||
|
|
||||||
def process_image(self, image):
|
def process_image(self, image):
|
||||||
"""Process an image."""
|
"""Process an image."""
|
||||||
response = post_image(self._url_check, image)
|
response = post_image(
|
||||||
if response is not None:
|
self._url_check, image, self._username, self._password)
|
||||||
|
if response:
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
if response_json['success']:
|
if response_json['success']:
|
||||||
total_faces = response_json['facesCount']
|
total_faces = response_json['facesCount']
|
||||||
@ -173,34 +237,8 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
|
|||||||
if (not self.hass.config.is_allowed_path(file_path)
|
if (not self.hass.config.is_allowed_path(file_path)
|
||||||
or not valid_file_path(file_path)):
|
or not valid_file_path(file_path)):
|
||||||
return
|
return
|
||||||
with open(file_path, 'rb') as open_file:
|
teach_file(
|
||||||
response = requests.post(
|
self._url_teach, name, file_path, self._username, self._password)
|
||||||
self._url_teach,
|
|
||||||
data={ATTR_NAME: name, 'id': file_path},
|
|
||||||
files={'file': open_file})
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
self.hass.bus.fire(
|
|
||||||
EVENT_CLASSIFIER_TEACH, {
|
|
||||||
ATTR_CLASSIFIER: CLASSIFIER,
|
|
||||||
ATTR_NAME: name,
|
|
||||||
FILE_PATH: file_path,
|
|
||||||
'success': True,
|
|
||||||
'message': None
|
|
||||||
})
|
|
||||||
|
|
||||||
elif response.status_code == 400:
|
|
||||||
_LOGGER.warning(
|
|
||||||
"%s teaching of file %s failed with message:%s",
|
|
||||||
CLASSIFIER, file_path, response.text)
|
|
||||||
self.hass.bus.fire(
|
|
||||||
EVENT_CLASSIFIER_TEACH, {
|
|
||||||
ATTR_CLASSIFIER: CLASSIFIER,
|
|
||||||
ATTR_NAME: name,
|
|
||||||
FILE_PATH: file_path,
|
|
||||||
'success': False,
|
|
||||||
'message': response.text
|
|
||||||
})
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_entity(self):
|
def camera_entity(self):
|
||||||
@ -218,4 +256,5 @@ class FaceClassifyEntity(ImageProcessingFaceEntity):
|
|||||||
return {
|
return {
|
||||||
'matched_faces': self._matched,
|
'matched_faces': self._matched,
|
||||||
'total_matched_faces': len(self._matched),
|
'total_matched_faces': len(self._matched),
|
||||||
|
'hostname': self._hostname
|
||||||
}
|
}
|
||||||
|
@ -7,19 +7,19 @@ import requests_mock
|
|||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME,
|
ATTR_ENTITY_ID, ATTR_NAME, CONF_FRIENDLY_NAME, CONF_PASSWORD,
|
||||||
CONF_IP_ADDRESS, CONF_PORT, STATE_UNKNOWN)
|
CONF_USERNAME, CONF_IP_ADDRESS, CONF_PORT,
|
||||||
|
HTTP_BAD_REQUEST, HTTP_OK, HTTP_UNAUTHORIZED, STATE_UNKNOWN)
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.components.image_processing as ip
|
import homeassistant.components.image_processing as ip
|
||||||
import homeassistant.components.image_processing.facebox as fb
|
import homeassistant.components.image_processing.facebox as fb
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
|
|
||||||
MOCK_IP = '192.168.0.1'
|
MOCK_IP = '192.168.0.1'
|
||||||
MOCK_PORT = '8080'
|
MOCK_PORT = '8080'
|
||||||
|
|
||||||
# Mock data returned by the facebox API.
|
# Mock data returned by the facebox API.
|
||||||
MOCK_ERROR = "No face found"
|
MOCK_BOX_ID = 'b893cc4f7fd6'
|
||||||
|
MOCK_ERROR_NO_FACE = "No face found"
|
||||||
MOCK_FACE = {'confidence': 0.5812028911604818,
|
MOCK_FACE = {'confidence': 0.5812028911604818,
|
||||||
'id': 'john.jpg',
|
'id': 'john.jpg',
|
||||||
'matched': True,
|
'matched': True,
|
||||||
@ -28,14 +28,21 @@ MOCK_FACE = {'confidence': 0.5812028911604818,
|
|||||||
|
|
||||||
MOCK_FILE_PATH = '/images/mock.jpg'
|
MOCK_FILE_PATH = '/images/mock.jpg'
|
||||||
|
|
||||||
|
MOCK_HEALTH = {'success': True,
|
||||||
|
'hostname': 'b893cc4f7fd6',
|
||||||
|
'metadata': {'boxname': 'facebox', 'build': 'development'},
|
||||||
|
'errors': []}
|
||||||
|
|
||||||
MOCK_JSON = {"facesCount": 1,
|
MOCK_JSON = {"facesCount": 1,
|
||||||
"success": True,
|
"success": True,
|
||||||
"faces": [MOCK_FACE]}
|
"faces": [MOCK_FACE]}
|
||||||
|
|
||||||
MOCK_NAME = 'mock_name'
|
MOCK_NAME = 'mock_name'
|
||||||
|
MOCK_USERNAME = 'mock_username'
|
||||||
|
MOCK_PASSWORD = 'mock_password'
|
||||||
|
|
||||||
# Faces data after parsing.
|
# Faces data after parsing.
|
||||||
PARSED_FACES = [{ATTR_NAME: 'John Lennon',
|
PARSED_FACES = [{fb.FACEBOX_NAME: 'John Lennon',
|
||||||
fb.ATTR_IMAGE_ID: 'john.jpg',
|
fb.ATTR_IMAGE_ID: 'john.jpg',
|
||||||
fb.ATTR_CONFIDENCE: 58.12,
|
fb.ATTR_CONFIDENCE: 58.12,
|
||||||
fb.ATTR_MATCHED: True,
|
fb.ATTR_MATCHED: True,
|
||||||
@ -62,6 +69,15 @@ VALID_CONFIG = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_healthybox():
|
||||||
|
"""Mock fb.check_box_health."""
|
||||||
|
check_box_health = 'homeassistant.components.image_processing.' \
|
||||||
|
'facebox.check_box_health'
|
||||||
|
with patch(check_box_health, return_value=MOCK_BOX_ID) as _mock_healthybox:
|
||||||
|
yield _mock_healthybox
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_isfile():
|
def mock_isfile():
|
||||||
"""Mock os.path.isfile."""
|
"""Mock os.path.isfile."""
|
||||||
@ -70,6 +86,14 @@ def mock_isfile():
|
|||||||
yield _mock_isfile
|
yield _mock_isfile
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_image():
|
||||||
|
"""Return a mock camera image."""
|
||||||
|
with patch('homeassistant.components.camera.demo.DemoCamera.camera_image',
|
||||||
|
return_value=b'Test') as image:
|
||||||
|
yield image
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_open_file():
|
def mock_open_file():
|
||||||
"""Mock open."""
|
"""Mock open."""
|
||||||
@ -79,6 +103,22 @@ def mock_open_file():
|
|||||||
yield _mock_open
|
yield _mock_open
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_box_health(caplog):
|
||||||
|
"""Test check box health."""
|
||||||
|
with requests_mock.Mocker() as mock_req:
|
||||||
|
url = "http://{}:{}/healthz".format(MOCK_IP, MOCK_PORT)
|
||||||
|
mock_req.get(url, status_code=HTTP_OK, json=MOCK_HEALTH)
|
||||||
|
assert fb.check_box_health(url, 'user', 'pass') == MOCK_BOX_ID
|
||||||
|
|
||||||
|
mock_req.get(url, status_code=HTTP_UNAUTHORIZED)
|
||||||
|
assert fb.check_box_health(url, None, None) is None
|
||||||
|
assert "AuthenticationError on facebox" in caplog.text
|
||||||
|
|
||||||
|
mock_req.get(url, exc=requests.exceptions.ConnectTimeout)
|
||||||
|
fb.check_box_health(url, None, None)
|
||||||
|
assert "ConnectionError: Is facebox running?" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
def test_encode_image():
|
def test_encode_image():
|
||||||
"""Test that binary data is encoded correctly."""
|
"""Test that binary data is encoded correctly."""
|
||||||
assert fb.encode_image(b'test') == 'dGVzdA=='
|
assert fb.encode_image(b'test') == 'dGVzdA=='
|
||||||
@ -100,22 +140,24 @@ def test_valid_file_path():
|
|||||||
assert not fb.valid_file_path('test_path')
|
assert not fb.valid_file_path('test_path')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
async def test_setup_platform(hass, mock_healthybox):
|
||||||
def mock_image():
|
|
||||||
"""Return a mock camera image."""
|
|
||||||
with patch('homeassistant.components.camera.demo.DemoCamera.camera_image',
|
|
||||||
return_value=b'Test') as image:
|
|
||||||
yield image
|
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_platform(hass):
|
|
||||||
"""Setup platform with one entity."""
|
"""Setup platform with one entity."""
|
||||||
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
||||||
assert hass.states.get(VALID_ENTITY_ID)
|
assert hass.states.get(VALID_ENTITY_ID)
|
||||||
|
|
||||||
|
|
||||||
async def test_process_image(hass, mock_image):
|
async def test_setup_platform_with_auth(hass, mock_healthybox):
|
||||||
"""Test processing of an image."""
|
"""Setup platform with one entity and auth."""
|
||||||
|
valid_config_auth = VALID_CONFIG.copy()
|
||||||
|
valid_config_auth[ip.DOMAIN][CONF_USERNAME] = MOCK_USERNAME
|
||||||
|
valid_config_auth[ip.DOMAIN][CONF_PASSWORD] = MOCK_PASSWORD
|
||||||
|
|
||||||
|
await async_setup_component(hass, ip.DOMAIN, valid_config_auth)
|
||||||
|
assert hass.states.get(VALID_ENTITY_ID)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_process_image(hass, mock_healthybox, mock_image):
|
||||||
|
"""Test successful processing of an image."""
|
||||||
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
||||||
assert hass.states.get(VALID_ENTITY_ID)
|
assert hass.states.get(VALID_ENTITY_ID)
|
||||||
|
|
||||||
@ -157,11 +199,12 @@ async def test_process_image(hass, mock_image):
|
|||||||
PARSED_FACES[0][fb.ATTR_BOUNDING_BOX])
|
PARSED_FACES[0][fb.ATTR_BOUNDING_BOX])
|
||||||
|
|
||||||
|
|
||||||
async def test_connection_error(hass, mock_image):
|
async def test_process_image_errors(hass, mock_healthybox, mock_image, caplog):
|
||||||
"""Test connection error."""
|
"""Test process_image errors."""
|
||||||
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
||||||
assert hass.states.get(VALID_ENTITY_ID)
|
assert hass.states.get(VALID_ENTITY_ID)
|
||||||
|
|
||||||
|
# Test connection error.
|
||||||
with requests_mock.Mocker() as mock_req:
|
with requests_mock.Mocker() as mock_req:
|
||||||
url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT)
|
url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT)
|
||||||
mock_req.register_uri(
|
mock_req.register_uri(
|
||||||
@ -171,34 +214,40 @@ async def test_connection_error(hass, mock_image):
|
|||||||
ip.SERVICE_SCAN,
|
ip.SERVICE_SCAN,
|
||||||
service_data=data)
|
service_data=data)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
assert "ConnectionError: Is facebox running?" in caplog.text
|
||||||
|
|
||||||
state = hass.states.get(VALID_ENTITY_ID)
|
state = hass.states.get(VALID_ENTITY_ID)
|
||||||
assert state.state == STATE_UNKNOWN
|
assert state.state == STATE_UNKNOWN
|
||||||
assert state.attributes.get('faces') == []
|
assert state.attributes.get('faces') == []
|
||||||
assert state.attributes.get('matched_faces') == {}
|
assert state.attributes.get('matched_faces') == {}
|
||||||
|
|
||||||
|
# Now test with bad auth.
|
||||||
|
with requests_mock.Mocker() as mock_req:
|
||||||
|
url = "http://{}:{}/facebox/check".format(MOCK_IP, MOCK_PORT)
|
||||||
|
mock_req.register_uri(
|
||||||
|
'POST', url, status_code=HTTP_UNAUTHORIZED)
|
||||||
|
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID}
|
||||||
|
await hass.services.async_call(ip.DOMAIN,
|
||||||
|
ip.SERVICE_SCAN,
|
||||||
|
service_data=data)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert "AuthenticationError on facebox" in caplog.text
|
||||||
|
|
||||||
async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file):
|
|
||||||
|
async def test_teach_service(
|
||||||
|
hass, mock_healthybox, mock_image,
|
||||||
|
mock_isfile, mock_open_file, caplog):
|
||||||
"""Test teaching of facebox."""
|
"""Test teaching of facebox."""
|
||||||
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
await async_setup_component(hass, ip.DOMAIN, VALID_CONFIG)
|
||||||
assert hass.states.get(VALID_ENTITY_ID)
|
assert hass.states.get(VALID_ENTITY_ID)
|
||||||
|
|
||||||
teach_events = []
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def mock_teach_event(event):
|
|
||||||
"""Mock event."""
|
|
||||||
teach_events.append(event)
|
|
||||||
|
|
||||||
hass.bus.async_listen(
|
|
||||||
'image_processing.teach_classifier', mock_teach_event)
|
|
||||||
|
|
||||||
# Patch out 'is_allowed_path' as the mock files aren't allowed
|
# Patch out 'is_allowed_path' as the mock files aren't allowed
|
||||||
hass.config.is_allowed_path = Mock(return_value=True)
|
hass.config.is_allowed_path = Mock(return_value=True)
|
||||||
|
|
||||||
|
# Test successful teach.
|
||||||
with requests_mock.Mocker() as mock_req:
|
with requests_mock.Mocker() as mock_req:
|
||||||
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
|
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
|
||||||
mock_req.post(url, status_code=200)
|
mock_req.post(url, status_code=HTTP_OK)
|
||||||
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
|
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
|
||||||
ATTR_NAME: MOCK_NAME,
|
ATTR_NAME: MOCK_NAME,
|
||||||
fb.FILE_PATH: MOCK_FILE_PATH}
|
fb.FILE_PATH: MOCK_FILE_PATH}
|
||||||
@ -206,17 +255,10 @@ async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file):
|
|||||||
ip.DOMAIN, fb.SERVICE_TEACH_FACE, service_data=data)
|
ip.DOMAIN, fb.SERVICE_TEACH_FACE, service_data=data)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(teach_events) == 1
|
# Now test with bad auth.
|
||||||
assert teach_events[0].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER
|
|
||||||
assert teach_events[0].data[ATTR_NAME] == MOCK_NAME
|
|
||||||
assert teach_events[0].data[fb.FILE_PATH] == MOCK_FILE_PATH
|
|
||||||
assert teach_events[0].data['success']
|
|
||||||
assert not teach_events[0].data['message']
|
|
||||||
|
|
||||||
# Now test the failed teaching.
|
|
||||||
with requests_mock.Mocker() as mock_req:
|
with requests_mock.Mocker() as mock_req:
|
||||||
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
|
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
|
||||||
mock_req.post(url, status_code=400, text=MOCK_ERROR)
|
mock_req.post(url, status_code=HTTP_UNAUTHORIZED)
|
||||||
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
|
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
|
||||||
ATTR_NAME: MOCK_NAME,
|
ATTR_NAME: MOCK_NAME,
|
||||||
fb.FILE_PATH: MOCK_FILE_PATH}
|
fb.FILE_PATH: MOCK_FILE_PATH}
|
||||||
@ -224,16 +266,37 @@ async def test_teach_service(hass, mock_image, mock_isfile, mock_open_file):
|
|||||||
fb.SERVICE_TEACH_FACE,
|
fb.SERVICE_TEACH_FACE,
|
||||||
service_data=data)
|
service_data=data)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
assert "AuthenticationError on facebox" in caplog.text
|
||||||
|
|
||||||
assert len(teach_events) == 2
|
# Now test the failed teaching.
|
||||||
assert teach_events[1].data[fb.ATTR_CLASSIFIER] == fb.CLASSIFIER
|
with requests_mock.Mocker() as mock_req:
|
||||||
assert teach_events[1].data[ATTR_NAME] == MOCK_NAME
|
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
|
||||||
assert teach_events[1].data[fb.FILE_PATH] == MOCK_FILE_PATH
|
mock_req.post(url, status_code=HTTP_BAD_REQUEST,
|
||||||
assert not teach_events[1].data['success']
|
text=MOCK_ERROR_NO_FACE)
|
||||||
assert teach_events[1].data['message'] == MOCK_ERROR
|
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
|
||||||
|
ATTR_NAME: MOCK_NAME,
|
||||||
|
fb.FILE_PATH: MOCK_FILE_PATH}
|
||||||
|
await hass.services.async_call(ip.DOMAIN,
|
||||||
|
fb.SERVICE_TEACH_FACE,
|
||||||
|
service_data=data)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert MOCK_ERROR_NO_FACE in caplog.text
|
||||||
|
|
||||||
|
# Now test connection error.
|
||||||
|
with requests_mock.Mocker() as mock_req:
|
||||||
|
url = "http://{}:{}/facebox/teach".format(MOCK_IP, MOCK_PORT)
|
||||||
|
mock_req.post(url, exc=requests.exceptions.ConnectTimeout)
|
||||||
|
data = {ATTR_ENTITY_ID: VALID_ENTITY_ID,
|
||||||
|
ATTR_NAME: MOCK_NAME,
|
||||||
|
fb.FILE_PATH: MOCK_FILE_PATH}
|
||||||
|
await hass.services.async_call(ip.DOMAIN,
|
||||||
|
fb.SERVICE_TEACH_FACE,
|
||||||
|
service_data=data)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert "ConnectionError: Is facebox running?" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_platform_with_name(hass):
|
async def test_setup_platform_with_name(hass, mock_healthybox):
|
||||||
"""Setup platform with one entity and a name."""
|
"""Setup platform with one entity and a name."""
|
||||||
named_entity_id = 'image_processing.{}'.format(MOCK_NAME)
|
named_entity_id = 'image_processing.{}'.format(MOCK_NAME)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user