Overwork Services/Discovery (#725)

* Update homeassistant.py

* Update validate.py

* Update exceptions.py

* Update services.py

* Update discovery.py

* fix gitignore

* Fix handling for discovery

* use object in ref

* lock down discovery API

* fix api

* Design

* Fix API

* fix lint

* fix

* Fix security layer

* add provide layer

* fix access

* change rating

* fix rights

* Fix API error handling

* raise error

* fix rights

* api

* fix handling

* fix

* debug

* debug json

* Fix validator

* fix error

* new url

* fix schema
This commit is contained in:
Pascal Vizeli 2018-09-29 19:49:08 +02:00 committed by GitHub
parent 4ef8c9d633
commit e5451973bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 263 additions and 202 deletions

3
.gitignore vendored
View File

@ -90,3 +90,6 @@ ENV/
# pylint
.pylint.d/
# VS Code
.vscode/

25
API.md
View File

@ -499,8 +499,8 @@ Get all available addons.
"audio": "bool",
"audio_input": "null|0,0",
"audio_output": "null|0,0",
"services": "null|['mqtt']",
"discovery": "null|['component/platform']"
"services_role": "['service:access']",
"discovery": "['service']"
}
```
@ -576,12 +576,13 @@ Write data to add-on stdin
### Service discovery
- GET `/services/discovery`
- GET `/discovery`
```json
{
"discovery": [
{
"provider": "name",
"addon": "slug",
"service": "name",
"uuid": "uuid",
"component": "component",
"platform": "null|platform",
@ -591,10 +592,11 @@ Write data to add-on stdin
}
```
- GET `/services/discovery/{UUID}`
- GET `/discovery/{UUID}`
```json
{
"provider": "name",
"addon": "slug",
"service": "name",
"uuid": "uuid",
"component": "component",
"platform": "null|platform",
@ -602,9 +604,10 @@ Write data to add-on stdin
}
```
- POST `/services/discovery`
- POST `/discovery`
```json
{
"service": "name",
"component": "component",
"platform": "null|platform",
"config": {}
@ -618,7 +621,7 @@ return:
}
```
- DEL `/services/discovery/{UUID}`
- DEL `/discovery/{UUID}`
- GET `/services`
```json
@ -627,7 +630,7 @@ return:
{
"slug": "name",
"available": "bool",
"provider": "null|name|list"
"providers": "list"
}
]
}
@ -635,12 +638,10 @@ return:
#### MQTT
This service performs an auto discovery to Home-Assistant.
- GET `/services/mqtt`
```json
{
"provider": "name",
"addon": "name",
"host": "xy",
"port": "8883",
"ssl": "bool",

View File

@ -239,24 +239,23 @@ class Addon(CoreSysAttributes):
return self._mesh.get(ATTR_STARTUP)
@property
def services(self):
def services_role(self):
"""Return dict of services with rights."""
raw_services = self._mesh.get(ATTR_SERVICES)
if not raw_services:
return None
return {}
formated_services = {}
services = {}
for data in raw_services:
service = RE_SERVICE.match(data)
formated_services[service.group('service')] = \
service.group('rights') or 'ro'
services[service.group('service')] = service.group('rights')
return formated_services
return services
@property
def discovery(self):
"""Return list of discoverable components/platforms."""
return self._mesh.get(ATTR_DISCOVERY)
return self._mesh.get(ATTR_DISCOVERY, [])
@property
def ports(self):

View File

@ -4,6 +4,7 @@ from .utils import get_hash_from_repository
from ..const import (
REPOSITORY_CORE, REPOSITORY_LOCAL, ATTR_NAME, ATTR_URL, ATTR_MAINTAINER)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError
UNKNOWN = 'unknown'
@ -67,6 +68,6 @@ class Repository(CoreSysAttributes):
def remove(self):
"""Remove add-on repository."""
if self._id in (REPOSITORY_CORE, REPOSITORY_LOCAL):
raise RuntimeError("Can't remove built-in repositories!")
raise APIError("Can't remove built-in repositories!")
self.git.remove()

View File

@ -28,10 +28,6 @@ def rating_security(addon):
elif addon.apparmor == SECURITY_PROFILE:
rating += 1
# API Access
if addon.access_hassio_api or addon.access_homeassistant_api:
rating += -1
# Privileged options
if addon.privileged in (PRIVILEGED_NET_ADMIN, PRIVILEGED_SYS_ADMIN,
PRIVILEGED_SYS_RAWIO, PRIVILEGED_SYS_PTRACE):

View File

@ -26,13 +26,13 @@ from ..const import (
PRIVILEGED_SYS_RESOURCE, PRIVILEGED_SYS_PTRACE,
ROLE_DEFAULT, ROLE_HOMEASSISTANT, ROLE_MANAGER, ROLE_ADMIN)
from ..validate import NETWORK_PORT, DOCKER_PORTS, ALSA_DEVICE
from ..services.validate import DISCOVERY_SERVICES
_LOGGER = logging.getLogger(__name__)
RE_VOLUME = re.compile(r"^(config|ssl|addons|backup|share)(?::(rw|:ro))?$")
RE_SERVICE = re.compile(r"^(?P<service>mqtt)(?::(?P<rights>rw|:ro))?$")
RE_DISCOVERY = re.compile(r"^(?P<component>\w*)(?:/(?P<platform>\w*>))?$")
RE_SERVICE = re.compile(r"^(?P<service>mqtt):(?P<rights>provide|want|need)$")
V_STR = 'str'
V_INT = 'int'
@ -143,7 +143,7 @@ SCHEMA_ADDON_CONFIG = vol.Schema({
vol.Optional(ATTR_LEGACY, default=False): vol.Boolean(),
vol.Optional(ATTR_DOCKER_API, default=False): vol.Boolean(),
vol.Optional(ATTR_SERVICES): [vol.Match(RE_SERVICE)],
vol.Optional(ATTR_DISCOVERY): [vol.Match(RE_DISCOVERY)],
vol.Optional(ATTR_DISCOVERY): [vol.In(DISCOVERY_SERVICES)],
vol.Required(ATTR_OPTIONS): dict,
vol.Required(ATTR_SCHEMA): vol.Any(vol.Schema({
vol.Coerce(str): vol.Any(SCHEMA_ELEMENT, [

View File

@ -211,11 +211,11 @@ class RestAPI(CoreSysAttributes):
api_discovery.coresys = self.coresys
self.webapp.add_routes([
web.get('/services/discovery', api_discovery.list),
web.get('/services/discovery/{uuid}', api_discovery.get_discovery),
web.delete('/services/discovery/{uuid}',
web.get('/discovery', api_discovery.list),
web.get('/discovery/{uuid}', api_discovery.get_discovery),
web.delete('/discovery/{uuid}',
api_discovery.del_discovery),
web.post('/services/discovery', api_discovery.set_discovery),
web.post('/discovery', api_discovery.set_discovery),
])
def _register_panel(self):

View File

@ -55,7 +55,7 @@ class APIAddons(CoreSysAttributes):
# Lookup itself
if addon_slug == 'self':
addon_slug = request.get(REQUEST_FROM)
return request.get(REQUEST_FROM)
addon = self.sys_addons.get(addon_slug)
if not addon:
@ -66,14 +66,6 @@ class APIAddons(CoreSysAttributes):
return addon
@staticmethod
def _pretty_devices(addon):
"""Return a simplified device list."""
dev_list = addon.devices
if not dev_list:
return None
return [row.split(':')[0] for row in dev_list]
@api_process
async def list(self, request):
"""Return all add-ons or repositories."""
@ -148,7 +140,7 @@ class APIAddons(CoreSysAttributes):
ATTR_PRIVILEGED: addon.privileged,
ATTR_FULL_ACCESS: addon.with_full_access,
ATTR_APPARMOR: addon.apparmor,
ATTR_DEVICES: self._pretty_devices(addon),
ATTR_DEVICES: _pretty_devices(addon),
ATTR_ICON: addon.with_icon,
ATTR_LOGO: addon.with_logo,
ATTR_CHANGELOG: addon.with_changelog,
@ -163,7 +155,7 @@ class APIAddons(CoreSysAttributes):
ATTR_AUDIO: addon.with_audio,
ATTR_AUDIO_INPUT: addon.audio_input,
ATTR_AUDIO_OUTPUT: addon.audio_output,
ATTR_SERVICES: addon.services,
ATTR_SERVICES: _pretty_services(addon),
ATTR_DISCOVERY: addon.discovery,
}
@ -328,3 +320,19 @@ class APIAddons(CoreSysAttributes):
data = await request.read()
return await asyncio.shield(addon.write_stdin(data))
def _pretty_devices(addon):
"""Return a simplified device list."""
dev_list = addon.devices
if not dev_list:
return None
return [row.split(':')[0] for row in dev_list]
def _pretty_services(addon):
"""Return a simplified services role list."""
services = []
for name, access in addon.services_role.items():
services.append(f"{name}:{access}")
return services

View File

@ -3,15 +3,18 @@ import voluptuous as vol
from .utils import api_process, api_validate
from ..const import (
ATTR_PROVIDER, ATTR_UUID, ATTR_COMPONENT, ATTR_PLATFORM, ATTR_CONFIG,
ATTR_DISCOVERY, REQUEST_FROM)
ATTR_ADDON, ATTR_UUID, ATTR_COMPONENT, ATTR_PLATFORM, ATTR_CONFIG,
ATTR_DISCOVERY, ATTR_SERVICE, REQUEST_FROM)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError, APIForbidden
from ..services.validate import SERVICE_ALL
SCHEMA_DISCOVERY = vol.Schema({
vol.Required(ATTR_SERVICE): vol.In(SERVICE_ALL),
vol.Required(ATTR_COMPONENT): vol.Coerce(str),
vol.Optional(ATTR_PLATFORM): vol.Any(None, vol.Coerce(str)),
vol.Optional(ATTR_CONFIG): vol.Any(None, dict),
vol.Optional(ATTR_PLATFORM): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_CONFIG): vol.Maybe(dict),
})
@ -22,16 +25,24 @@ class APIDiscovery(CoreSysAttributes):
"""Extract discovery message from URL."""
message = self.sys_discovery.get(request.match_info.get('uuid'))
if not message:
raise RuntimeError("Discovery message not found")
raise APIError("Discovery message not found")
return message
def _check_permission_ha(self, request):
"""Check permission for API call / Home Assistant."""
if request[REQUEST_FROM] != self.sys_homeassistant:
raise APIForbidden("Only HomeAssistant can use this API!")
@api_process
async def list(self, request):
"""Show register services."""
self._check_permission_ha(request)
discovery = []
for message in self.sys_discovery.list_messages:
discovery.append({
ATTR_PROVIDER: message.provider,
ATTR_ADDON: message.addon,
ATTR_SERVICE: message.service,
ATTR_UUID: message.uuid,
ATTR_COMPONENT: message.component,
ATTR_PLATFORM: message.platform,
@ -44,8 +55,14 @@ class APIDiscovery(CoreSysAttributes):
async def set_discovery(self, request):
"""Write data into a discovery pipeline."""
body = await api_validate(SCHEMA_DISCOVERY, request)
message = self.sys_discovery.send(
provider=request[REQUEST_FROM], **body)
addon = request[REQUEST_FROM]
# Access?
if body[ATTR_SERVICE] not in addon.discovery:
raise APIForbidden(f"Can't use discovery!")
# Process discovery message
message = self.sys_discovery.send(addon, **body)
return {ATTR_UUID: message.uuid}
@ -54,8 +71,12 @@ class APIDiscovery(CoreSysAttributes):
"""Read data into a discovery message."""
message = self._extract_message(request)
# HomeAssistant?
self._check_permission_ha(request)
return {
ATTR_PROVIDER: message.provider,
ATTR_ADDON: message.addon,
ATTR_SERVICE: message.service,
ATTR_UUID: message.uuid,
ATTR_COMPONENT: message.component,
ATTR_PLATFORM: message.platform,
@ -66,6 +87,11 @@ class APIDiscovery(CoreSysAttributes):
async def del_discovery(self, request):
"""Delete data into a discovery message."""
message = self._extract_message(request)
addon = request[REQUEST_FROM]
# Permission
if message.addon != addon.slug:
raise APIForbidden(f"Can't remove discovery message")
self.sys_discovery.remove(message)
return True

View File

@ -13,6 +13,7 @@ from ..const import (
ATTR_REFRESH_TOKEN, CONTENT_TYPE_BINARY)
from ..coresys import CoreSysAttributes
from ..validate import NETWORK_PORT, DOCKER_IMAGE
from ..exceptions import APIError
_LOGGER = logging.getLogger(__name__)
@ -94,7 +95,7 @@ class APIHomeAssistant(CoreSysAttributes):
"""Return resource information."""
stats = await self.sys_homeassistant.stats()
if not stats:
raise RuntimeError("No stats available")
raise APIError("No stats available")
return {
ATTR_CPU_PERCENT: stats.cpu_percent,
@ -139,6 +140,6 @@ class APIHomeAssistant(CoreSysAttributes):
"""Check configuration of Home Assistant."""
result = await self.sys_homeassistant.check_config()
if not result.valid:
raise RuntimeError(result.log)
raise APIError(result.log)
return True

View File

@ -26,7 +26,6 @@ NO_SECURITY_CHECK = re.compile(
r"|/homeassistant/api/.*"
r"|/homeassistant/websocket"
r"|/supervisor/ping"
r"|/services.*"
r")$"
)
@ -35,6 +34,8 @@ ADDONS_API_BYPASS = re.compile(
r"^(?:"
r"|/addons/self/(?!security)[^/]+"
r"|/version"
r"|/services.*"
r"|/discovery.*"
r")$"
)
@ -58,8 +59,7 @@ ADDONS_ROLE_ACCESS = {
r"|/hardware/.+"
r"|/hassos/.+"
r"|/supervisor/.+"
r"|/addons/[^/]+/(?!security|options).+"
r"|/addons(?:/self/(?!security).+)?"
r"|/addons/[^/]+/(?!security).+"
r"|/snapshots.*"
r")$"
),
@ -102,12 +102,12 @@ class SecurityMiddleware(CoreSysAttributes):
if hassio_token in (self.sys_homeassistant.uuid,
self.sys_homeassistant.hassio_token):
_LOGGER.debug("%s access from Home Assistant", request.path)
request_from = 'homeassistant'
request_from = self.sys_homeassistant
# Host
if hassio_token == self.sys_machine_id:
_LOGGER.debug("%s access from Host", request.path)
request_from = 'host'
request_from = self.sys_host
# Add-on
addon = None
@ -117,12 +117,12 @@ class SecurityMiddleware(CoreSysAttributes):
# Check Add-on API access
if addon and ADDONS_API_BYPASS.match(request.path):
_LOGGER.debug("Passthrough %s from %s", request.path, addon.slug)
request_from = addon.slug
request_from = addon
elif addon and addon.access_hassio_api:
# Check Role
if ADDONS_ROLE_ACCESS[addon.hassio_role].match(request.path):
_LOGGER.info("%s access from %s", request.path, addon.slug)
request_from = addon.slug
request_from = addon
else:
_LOGGER.warning("%s no role for %s", request.path, addon.slug)

View File

@ -2,8 +2,10 @@
from .utils import api_process, api_validate
from ..const import (
ATTR_AVAILABLE, ATTR_PROVIDER, ATTR_SLUG, ATTR_SERVICES, REQUEST_FROM)
ATTR_AVAILABLE, ATTR_PROVIDERS, ATTR_SLUG, ATTR_SERVICES, REQUEST_FROM,
PROVIDE_SERVICE)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError, APIForbidden
class APIServices(CoreSysAttributes):
@ -13,7 +15,7 @@ class APIServices(CoreSysAttributes):
"""Return service, throw an exception if it doesn't exist."""
service = self.sys_services.get(request.match_info.get('service'))
if not service:
raise RuntimeError("Service does not exist")
raise APIError("Service does not exist")
return service
@ -25,7 +27,7 @@ class APIServices(CoreSysAttributes):
services.append({
ATTR_SLUG: service.slug,
ATTR_AVAILABLE: service.enabled,
ATTR_PROVIDER: service.provider,
ATTR_PROVIDERS: service.providers,
})
return {ATTR_SERVICES: services}
@ -35,21 +37,39 @@ class APIServices(CoreSysAttributes):
"""Write data into a service."""
service = self._extract_service(request)
body = await api_validate(service.schema, request)
addon = request[REQUEST_FROM]
return service.set_service_data(request[REQUEST_FROM], body)
_check_access(request, service.slug)
service.set_service_data(addon, body)
@api_process
async def get_service(self, request):
"""Read data into a service."""
service = self._extract_service(request)
return {
ATTR_AVAILABLE: service.enabled,
service.slug: service.get_service_data(),
}
# Access
_check_access(request, service.slug)
if not service.enabled:
raise APIError("Service not enabled")
return service.get_service_data()
@api_process
async def del_service(self, request):
"""Delete data into a service."""
service = self._extract_service(request)
return service.del_service_data(request[REQUEST_FROM])
addon = request[REQUEST_FROM]
# Access
_check_access(request, service.slug, True)
service.del_service_data(addon)
def _check_access(request, service, provide=False):
"""Raise error if the rights are wrong."""
addon = request[REQUEST_FROM]
if not addon.services_role.get(service):
raise APIForbidden(f"No access to {service} service!")
if provide and addon.services_role.get(service) != PROVIDE_SERVICE:
raise APIForbidden(f"No access to write {service} service!")

View File

@ -14,6 +14,7 @@ from ..const import (
ATTR_HOMEASSISTANT, ATTR_VERSION, ATTR_SIZE, ATTR_FOLDERS, ATTR_TYPE,
ATTR_SNAPSHOTS, ATTR_PASSWORD, ATTR_PROTECTED, CONTENT_TYPE_TAR)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError
_LOGGER = logging.getLogger(__name__)
@ -52,7 +53,7 @@ class APISnapshots(CoreSysAttributes):
"""Return snapshot, throw an exception if it doesn't exist."""
snapshot = self.sys_snapshots.get(request.match_info.get('snapshot'))
if not snapshot:
raise RuntimeError("Snapshot does not exist")
raise APIError("Snapshot does not exist")
return snapshot
@api_process

View File

@ -14,6 +14,7 @@ from ..const import (
ATTR_BLK_WRITE, CONTENT_TYPE_BINARY, ATTR_ICON)
from ..coresys import CoreSysAttributes
from ..validate import validate_timezone, WAIT_BOOT, REPOSITORIES, CHANNELS
from ..exceptions import APIError
_LOGGER = logging.getLogger(__name__)
@ -93,7 +94,7 @@ class APISupervisor(CoreSysAttributes):
"""Return resource information."""
stats = await self.sys_supervisor.stats()
if not stats:
raise RuntimeError("No stats available")
raise APIError("No stats available")
return {
ATTR_CPU_PERCENT: stats.cpu_percent,
@ -112,7 +113,7 @@ class APISupervisor(CoreSysAttributes):
version = body.get(ATTR_VERSION, self.sys_updater.version_hassio)
if version == self.sys_supervisor.version:
raise RuntimeError("Version {} is already in use".format(version))
raise APIError("Version {} is already in use".format(version))
return await asyncio.shield(
self.sys_supervisor.update(version))
@ -128,7 +129,7 @@ class APISupervisor(CoreSysAttributes):
for result in results:
if result.exception() is not None:
raise RuntimeError("Some reload task fails!")
raise APIError("Some reload task fails!")
return True

View File

@ -9,7 +9,7 @@ from voluptuous.humanize import humanize_error
from ..const import (
JSON_RESULT, JSON_DATA, JSON_MESSAGE, RESULT_OK, RESULT_ERROR,
CONTENT_TYPE_BINARY)
from ..exceptions import HassioError
from ..exceptions import HassioError, APIError, APIForbidden
_LOGGER = logging.getLogger(__name__)
@ -21,7 +21,7 @@ def json_loads(data):
try:
return json.loads(data)
except json.JSONDecodeError:
raise RuntimeError("Invalid json")
raise APIError("Invalid json")
def api_process(method):
@ -30,10 +30,10 @@ def api_process(method):
"""Return API information."""
try:
answer = await method(api, *args, **kwargs)
except HassioError:
return api_return_error()
except RuntimeError as err:
except (APIError, APIForbidden) as err:
return api_return_error(message=str(err))
except HassioError:
return api_return_error(message="Unknown Error, see logs")
if isinstance(answer, dict):
return api_return_ok(data=answer)
@ -55,7 +55,7 @@ def api_process_raw(content):
try:
msg_data = await method(api, *args, **kwargs)
msg_type = content
except RuntimeError as err:
except (APIError, APIForbidden) as err:
msg_data = str(err).encode()
msg_type = CONTENT_TYPE_BINARY
except HassioError:
@ -90,6 +90,6 @@ async def api_validate(schema, request):
try:
data = schema(data)
except vol.Invalid as ex:
raise RuntimeError(humanize_error(data, ex)) from None
raise APIError(humanize_error(data, ex)) from None
return data

View File

@ -74,6 +74,7 @@ ATTR_TYPE = 'type'
ATTR_SOURCE = 'source'
ATTR_FEATURES = 'features'
ATTR_ADDONS = 'addons'
ATTR_PROVIDERS = 'providers'
ATTR_VERSION = 'version'
ATTR_VERSION_LATEST = 'version_latest'
ATTR_AUTO_UART = 'auto_uart'
@ -107,8 +108,6 @@ ATTR_MAINTAINER = 'maintainer'
ATTR_PASSWORD = 'password'
ATTR_TOTP = 'totp'
ATTR_INITIALIZE = 'initialize'
ATTR_SESSION = 'session'
ATTR_SESSIONS = 'sessions'
ATTR_LOCATON = 'location'
ATTR_BUILD = 'build'
ATTR_DEVICES = 'devices'
@ -154,7 +153,7 @@ ATTR_MEMORY_LIMIT = 'memory_limit'
ATTR_MEMORY_USAGE = 'memory_usage'
ATTR_BLK_READ = 'blk_read'
ATTR_BLK_WRITE = 'blk_write'
ATTR_PROVIDER = 'provider'
ATTR_ADDON = 'addon'
ATTR_AVAILABLE = 'available'
ATTR_HOST = 'host'
ATTR_USERNAME = 'username'
@ -163,8 +162,8 @@ ATTR_DISCOVERY = 'discovery'
ATTR_PLATFORM = 'platform'
ATTR_COMPONENT = 'component'
ATTR_CONFIG = 'config'
ATTR_DISCOVERY_ID = 'discovery_id'
ATTR_SERVICES = 'services'
ATTR_SERVICE = 'service'
ATTR_DISCOVERY = 'discovery'
ATTR_PROTECTED = 'protected'
ATTR_CRYPTO = 'crypto'
@ -188,6 +187,9 @@ ATTR_HASSIO_ROLE = 'hassio_role'
ATTR_SUPERVISOR = 'supervisor'
SERVICE_MQTT = 'mqtt'
PROVIDE_SERVICE = 'provide'
NEED_SERVICE = 'need'
WANT_SERVICE = 'want'
STARTUP_INITIALIZE = 'initialize'
STARTUP_SYSTEM = 'system'

View File

@ -81,13 +81,25 @@ class HostAppArmorError(HostError):
# API
class APIError(HassioError):
class APIError(HassioError, RuntimeError):
"""API errors."""
pass
class APINotSupportedError(HassioNotSupportedError):
"""API not supported error."""
class APIForbidden(APIError):
"""API forbidden error."""
pass
# Service / Discovery
class DiscoveryError(HassioError):
"""Discovery Errors."""
pass
class ServicesError(HassioError):
"""Services Errors."""
pass

View File

@ -439,19 +439,6 @@ class HomeAssistant(JsonConfig, CoreSysAttributes):
_LOGGER.warning("Home Assistant API config mismatch: %d", err)
return False
async def send_event(self, event_type, event_data=None):
"""Send event to Home-Assistant."""
with suppress(HomeAssistantAPIError):
async with self.make_request(
'get', f'api/events/{event_type}'
) as resp:
if resp.status in (200, 201):
return
err = resp.status
_LOGGER.warning("Home Assistant event %s fails: %s", event_type, err)
return HomeAssistantError()
async def _block_till_run(self):
"""Block until Home-Assistant is booting up or startup timeout."""
start_time = time.monotonic()

View File

@ -1,14 +1,20 @@
"""Handle discover message for Home Assistant."""
import logging
from contextlib import suppress
from uuid import uuid4
from ..const import ATTR_UUID
import attr
import voluptuous as vol
from voluptuous.humanize import humanize_error
from .validate import DISCOVERY_SERVICES
from ..coresys import CoreSysAttributes
from ..exceptions import DiscoveryError, HomeAssistantAPIError
_LOGGER = logging.getLogger(__name__)
EVENT_DISCOVERY_ADD = 'hassio_discovery_add'
EVENT_DISCOVERY_DEL = 'hassio_discovery_del'
CMD_NEW = 'post'
CMD_DEL = 'delete'
class Discovery(CoreSysAttributes):
@ -32,7 +38,7 @@ class Discovery(CoreSysAttributes):
"""Write discovery message into data file."""
messages = []
for message in self.message_obj.values():
messages.append(message.raw())
messages.append(attr.asdict(message))
self._data.clear()
self._data.extend(messages)
@ -52,26 +58,31 @@ class Discovery(CoreSysAttributes):
"""Return list of available discovery messages."""
return self.message_obj.values()
def send(self, provider, component, platform=None, config=None):
def send(self, addon, service, component, platform, config):
"""Send a discovery message to Home Assistant."""
message = Message(provider, component, platform, config)
try:
DISCOVERY_SERVICES[service](config)
except vol.Invalid as err:
_LOGGER.error(
"Invalid discovery %s config", humanize_error(config, err))
raise DiscoveryError() from None
# Create message
message = Message(addon.slug, service, component, platform, config)
# Already exists?
for exists_message in self.message_obj:
if exists_message == message:
_LOGGER.warning("Found duplicate discovery message from %s",
provider)
return exists_message
for old_message in self.message_obj:
if old_message != message:
continue
_LOGGER.warning("Duplicate discovery message from %s", addon.slug)
return old_message
_LOGGER.info("Send discovery to Home Assistant %s/%s from %s",
component, platform, provider)
component, platform, addon.slug)
self.message_obj[message.uuid] = message
self.save()
# Send event to Home Assistant
self.sys_create_task(self.sys_homeassistant.send_event(
EVENT_DISCOVERY_ADD, {ATTR_UUID: message.uuid}))
self.sys_create_task(self._push_discovery(message.uuid, CMD_NEW))
return message
def remove(self, message):
@ -79,29 +90,31 @@ class Discovery(CoreSysAttributes):
self.message_obj.pop(message.uuid, None)
self.save()
# send event to Home-Assistant
self.sys_create_task(self.sys_homeassistant.send_event(
EVENT_DISCOVERY_DEL, {ATTR_UUID: message.uuid}))
_LOGGER.info("Delete discovery to Home Assistant %s/%s from %s",
message.component, message.platform, message.addon)
self.sys_create_task(self._push_discovery(message.uuid, CMD_DEL))
async def _push_discovery(self, uuid, command):
"""Send a discovery request."""
if not await self.sys_homeassistant.check_api_state():
_LOGGER.info("Discovery %s mesage ignore", uuid)
return
with suppress(HomeAssistantAPIError):
async with self.sys_homeassistant.make_request(
command, f"api/hassio_push/discovery/{uuid}"):
_LOGGER.info("Discovery %s message send", uuid)
return
_LOGGER.warning("Discovery %s message fail", uuid)
@attr.s
class Message:
"""Represent a single Discovery message."""
def __init__(self, provider, component, platform, config, uuid=None):
"""Initialize discovery message."""
self.provider = provider
self.component = component
self.platform = platform
self.config = config
self.uuid = uuid or uuid4().hex
def raw(self):
"""Return raw discovery message."""
return self.__dict__
def __eq__(self, other):
"""Compare with other message."""
for attribute in ('provider', 'component', 'platform', 'config'):
if getattr(self, attribute) != getattr(other, attribute):
return False
return True
addon = attr.ib()
service = attr.ib()
component = attr.ib()
platform = attr.ib()
config = attr.ib()
uuid = attr.ib(factory=lambda: uuid4().hex, cmp=False)

View File

@ -1,6 +1,7 @@
"""Interface for single service."""
from ..coresys import CoreSysAttributes
from ..const import PROVIDE_SERVICE
class ServiceInterface(CoreSysAttributes):
@ -26,9 +27,13 @@ class ServiceInterface(CoreSysAttributes):
return None
@property
def provider(self):
"""Return name of service provider."""
return None
def providers(self):
"""Return name of service providers addon."""
addons = []
for addon in self.sys_addons.list_installed:
if addon.services_role.get(self.slug) == PROVIDE_SERVICE:
addons.append(addon.slug)
return addons
@property
def enabled(self):
@ -45,10 +50,10 @@ class ServiceInterface(CoreSysAttributes):
return self._data
return None
def set_service_data(self, provider, data):
def set_service_data(self, addon, data):
"""Write the data into service object."""
raise NotImplementedError()
def del_service_data(self, provider):
def del_service_data(self, addon):
"""Remove the data from service object."""
raise NotImplementedError()

View File

@ -3,9 +3,8 @@ import logging
from .interface import ServiceInterface
from .validate import SCHEMA_SERVICE_MQTT
from ..const import (
ATTR_PROVIDER, SERVICE_MQTT, ATTR_HOST, ATTR_PORT, ATTR_USERNAME,
ATTR_PASSWORD, ATTR_PROTOCOL, ATTR_DISCOVERY_ID)
from ..const import ATTR_ADDON, SERVICE_MQTT
from ..exceptions import ServicesError
_LOGGER = logging.getLogger(__name__)
@ -28,62 +27,24 @@ class MQTTService(ServiceInterface):
"""Return data schema of this service."""
return SCHEMA_SERVICE_MQTT
@property
def provider(self):
"""Return name of service provider."""
return self._data.get(ATTR_PROVIDER)
@property
def hass_config(self):
"""Return Home Assistant MQTT config."""
if not self.enabled:
return None
hass_config = {
'host': self._data[ATTR_HOST],
'port': self._data[ATTR_PORT],
'protocol': self._data[ATTR_PROTOCOL]
}
if ATTR_USERNAME in self._data:
hass_config['user']: self._data[ATTR_USERNAME]
if ATTR_PASSWORD in self._data:
hass_config['password']: self._data[ATTR_PASSWORD]
return hass_config
def set_service_data(self, provider, data):
def set_service_data(self, addon, data):
"""Write the data into service object."""
if self.enabled:
_LOGGER.error("It is already a MQTT in use from %s", self.provider)
return False
_LOGGER.error(
"It is already a MQTT in use from %s", self._data[ATTR_ADDON])
raise ServicesError()
self._data.update(data)
self._data[ATTR_PROVIDER] = provider
self._data[ATTR_ADDON] = addon.slug
if provider == 'homeassistant':
_LOGGER.info("Use MQTT settings from Home Assistant")
self.save()
return True
# Discover MQTT to Home Assistant
message = self.sys_discovery.send(
provider, SERVICE_MQTT, None, self.hass_config)
self._data[ATTR_DISCOVERY_ID] = message.uuid
_LOGGER.info("Set %s as service provider for mqtt", addon.slug)
self.save()
return True
def del_service_data(self, provider):
def del_service_data(self, addon):
"""Remove the data from service object."""
if not self.enabled:
_LOGGER.warning("Can't remove not exists services")
return False
discovery_id = self._data.get(ATTR_DISCOVERY_ID)
if discovery_id:
self.sys_discovery.remove(
self.sys_discovery.get(discovery_id))
raise ServicesError()
self._data.clear()
self.save()
return True

View File

@ -1,20 +1,40 @@
"""Validate services schema."""
import re
import voluptuous as vol
from ..const import (
SERVICE_MQTT, ATTR_HOST, ATTR_PORT, ATTR_PASSWORD, ATTR_USERNAME, ATTR_SSL,
ATTR_PROVIDER, ATTR_PROTOCOL, ATTR_DISCOVERY, ATTR_COMPONENT, ATTR_UUID,
ATTR_PLATFORM, ATTR_CONFIG, ATTR_DISCOVERY_ID)
ATTR_ADDON, ATTR_PROTOCOL, ATTR_DISCOVERY, ATTR_COMPONENT, ATTR_UUID,
ATTR_PLATFORM, ATTR_CONFIG, ATTR_SERVICE)
from ..validate import NETWORK_PORT
UUID_MATCH = re.compile(r"^[0-9a-f]{32}$")
SERVICE_ALL = [
SERVICE_MQTT
]
def schema_or(schema):
"""Allow schema or empty."""
def _wrapper(value):
"""Wrapper for validator."""
if not value:
return value
return schema(value)
return _wrapper
SCHEMA_DISCOVERY = vol.Schema([
vol.Schema({
vol.Required(ATTR_UUID): vol.Match(r"^[0-9a-f]{32}$"),
vol.Required(ATTR_PROVIDER): vol.Coerce(str),
vol.Required(ATTR_UUID): vol.Match(UUID_MATCH),
vol.Required(ATTR_ADDON): vol.Coerce(str),
vol.Required(ATTR_SERVICE): vol.In(SERVICE_ALL),
vol.Required(ATTR_COMPONENT): vol.Coerce(str),
vol.Required(ATTR_PLATFORM): vol.Any(None, vol.Coerce(str)),
vol.Required(ATTR_CONFIG): vol.Any(None, dict),
vol.Required(ATTR_PLATFORM): vol.Maybe(vol.Coerce(str)),
vol.Required(ATTR_CONFIG): vol.Maybe(dict),
}, extra=vol.REMOVE_EXTRA)
])
@ -32,12 +52,16 @@ SCHEMA_SERVICE_MQTT = vol.Schema({
SCHEMA_CONFIG_MQTT = SCHEMA_SERVICE_MQTT.extend({
vol.Required(ATTR_PROVIDER): vol.Coerce(str),
vol.Optional(ATTR_DISCOVERY_ID): vol.Match(r"^[0-9a-f]{32}$"),
vol.Required(ATTR_ADDON): vol.Coerce(str),
})
SCHEMA_SERVICES_FILE = vol.Schema({
vol.Optional(SERVICE_MQTT, default=dict): vol.Any({}, SCHEMA_CONFIG_MQTT),
vol.Optional(ATTR_DISCOVERY, default=list): vol.Any([], SCHEMA_DISCOVERY),
vol.Optional(SERVICE_MQTT, default=dict): schema_or(SCHEMA_CONFIG_MQTT),
vol.Optional(ATTR_DISCOVERY, default=list): schema_or(SCHEMA_DISCOVERY),
}, extra=vol.REMOVE_EXTRA)
DISCOVERY_SERVICES = {
SERVICE_MQTT: SCHEMA_SERVICE_MQTT,
}