Add Black

Paulus Schoutsen 2019-07-30 16:59:12 -07:00
parent 0490167a12
commit da05dfe708
16 changed files with 401 additions and 272 deletions


@@ -17,6 +17,10 @@
   "python.pythonPath": "/usr/local/bin/python",
   "python.linting.pylintEnabled": true,
   "python.linting.enabled": true,
+  "python.formatting.provider": "black",
+  "editor.formatOnPaste": false,
+  "editor.formatOnSave": true,
+  "editor.formatOnType": true,
   "files.trimTrailingWhitespace": true,
   "editor.rulers": [80],
   "terminal.integrated.shell.linux": "/bin/bash",

.pre-commit-config.yaml Normal file

@@ -0,0 +1,8 @@
+repos:
+  - repo: https://github.com/python/black
+    rev: 19.3b0
+    hooks:
+      - id: black
+        args:
+          - --safe
+          - --quiet
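For reference, the hook above can also be exercised by hand before pushing; a minimal sketch, assuming pre-commit is installed (it is pinned in requirements_test.txt) and the commands are run from the repository root:

    pre-commit install                  # set up the git hook once per clone
    pre-commit run black --all-files    # apply the Black hook to the whole tree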


@@ -38,7 +38,7 @@ stages:
         python -m venv venv
         . venv/bin/activate
-        pip install -r requirements_test.txt
+        pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
       displayName: 'Setup Env'
     - script: |
         . venv/bin/activate
@@ -63,6 +63,21 @@
         . venv/bin/activate
         ./script/gen_requirements_all.py validate
       displayName: 'requirements_all validate'
+  - job: 'CheckFormat'
+    pool:
+      vmImage: 'ubuntu-latest'
+    container: $[ variables['PythonMain'] ]
+    steps:
+    - script: |
+        python -m venv venv
+        . venv/bin/activate
+        pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
+      displayName: 'Setup Env'
+    - script: |
+        . venv/bin/activate
+        ./script/check_format
+      displayName: 'Check Black formatting'

 - stage: 'Tests'
   dependsOn:
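The CheckFormat job only repeats what can be done locally; a rough sketch of the same steps outside CI, assuming a POSIX shell in a clone of the repository:

    python3 -m venv venv
    . venv/bin/activate
    pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
    ./script/check_format    # exits non-zero if Black would reformat anything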


@@ -21,42 +21,42 @@ from homeassistant.helpers.entity import Entity
 _LOGGER = logging.getLogger(__name__)

-ATTR_STOP_ID = 'stop_id'
-ATTR_STOP_NAME = 'stop'
-ATTR_ROUTE = 'route'
-ATTR_TYPE = 'type'
+ATTR_STOP_ID = "stop_id"
+ATTR_STOP_NAME = "stop"
+ATTR_ROUTE = "route"
+ATTR_TYPE = "type"
 ATTR_DIRECTION = "direction"
-ATTR_DUE_IN = 'due_in'
-ATTR_DUE_AT = 'due_at'
-ATTR_NEXT_UP = 'next_departures'
+ATTR_DUE_IN = "due_in"
+ATTR_DUE_AT = "due_at"
+ATTR_NEXT_UP = "next_departures"

 ATTRIBUTION = "Data provided by rejseplanen.dk"

-CONF_STOP_ID = 'stop_id'
-CONF_ROUTE = 'route'
-CONF_DIRECTION = 'direction'
-CONF_DEPARTURE_TYPE = 'departure_type'
+CONF_STOP_ID = "stop_id"
+CONF_ROUTE = "route"
+CONF_DIRECTION = "direction"
+CONF_DEPARTURE_TYPE = "departure_type"

-DEFAULT_NAME = 'Next departure'
-ICON = 'mdi:bus'
+DEFAULT_NAME = "Next departure"
+ICON = "mdi:bus"

 SCAN_INTERVAL = timedelta(minutes=1)

-BUS_TYPES = ['BUS', 'EXB', 'TB']
-TRAIN_TYPES = ['LET', 'S', 'REG', 'IC', 'LYN', 'TOG']
-METRO_TYPES = ['M']
+BUS_TYPES = ["BUS", "EXB", "TB"]
+TRAIN_TYPES = ["LET", "S", "REG", "IC", "LYN", "TOG"]
+METRO_TYPES = ["M"]

-PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
-    vol.Required(CONF_STOP_ID): cv.string,
-    vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
-    vol.Optional(CONF_ROUTE, default=[]):
-        vol.All(cv.ensure_list, [cv.string]),
-    vol.Optional(CONF_DIRECTION, default=[]):
-        vol.All(cv.ensure_list, [cv.string]),
-    vol.Optional(CONF_DEPARTURE_TYPE, default=[]):
-        vol.All(cv.ensure_list,
-                [vol.In([*BUS_TYPES, *TRAIN_TYPES, *METRO_TYPES])])
-})
+PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
+    {
+        vol.Required(CONF_STOP_ID): cv.string,
+        vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
+        vol.Optional(CONF_ROUTE, default=[]): vol.All(cv.ensure_list, [cv.string]),
+        vol.Optional(CONF_DIRECTION, default=[]): vol.All(cv.ensure_list, [cv.string]),
+        vol.Optional(CONF_DEPARTURE_TYPE, default=[]): vol.All(
+            cv.ensure_list, [vol.In([*BUS_TYPES, *TRAIN_TYPES, *METRO_TYPES])]
+        ),
+    }
+)


 def due_in_minutes(timestamp):
@@ -64,8 +64,9 @@ def due_in_minutes(timestamp):
     The timestamp should be in the format day.month.year hour:minute
     """
-    diff = datetime.strptime(
-        timestamp, "%d.%m.%y %H:%M") - dt_util.now().replace(tzinfo=None)
+    diff = datetime.strptime(timestamp, "%d.%m.%y %H:%M") - dt_util.now().replace(
+        tzinfo=None
+    )

     return int(diff.total_seconds() // 60)
@@ -79,8 +80,9 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
     departure_type = config[CONF_DEPARTURE_TYPE]

     data = PublicTransportData(stop_id, route, direction, departure_type)
-    add_devices([RejseplanenTransportSensor(
-        data, stop_id, route, direction, name)], True)
+    add_devices(
+        [RejseplanenTransportSensor(data, stop_id, route, direction, name)], True
+    )


 class RejseplanenTransportSensor(Entity):
@@ -124,14 +126,14 @@ class RejseplanenTransportSensor(Entity):
             ATTR_STOP_NAME: self._times[0][ATTR_STOP_NAME],
             ATTR_STOP_ID: self._stop_id,
             ATTR_ATTRIBUTION: ATTRIBUTION,
-            ATTR_NEXT_UP: next_up
+            ATTR_NEXT_UP: next_up,
         }
         return {k: v for k, v in params.items() if v}

     @property
     def unit_of_measurement(self):
         """Return the unit this state is expressed in."""
-        return 'min'
+        return "min"

     @property
     def icon(self):
@@ -148,7 +150,7 @@ class RejseplanenTransportSensor(Entity):
         pass


-class PublicTransportData():
+class PublicTransportData:
     """The Class for handling the data retrieval."""

     def __init__(self, stop_id, route, direction, departure_type):
@@ -161,16 +163,21 @@ class PublicTransportData():
     def empty_result(self):
         """Object returned when no departures are found."""
-        return [{ATTR_DUE_IN: 'n/a',
-                 ATTR_DUE_AT: 'n/a',
-                 ATTR_TYPE: 'n/a',
-                 ATTR_ROUTE: self.route,
-                 ATTR_DIRECTION: 'n/a',
-                 ATTR_STOP_NAME: 'n/a'}]
+        return [
+            {
+                ATTR_DUE_IN: "n/a",
+                ATTR_DUE_AT: "n/a",
+                ATTR_TYPE: "n/a",
+                ATTR_ROUTE: self.route,
+                ATTR_DIRECTION: "n/a",
+                ATTR_STOP_NAME: "n/a",
+            }
+        ]

     def update(self):
         """Get the latest data from rejseplanen."""
         import rjpl
         self.info = []

         def intersection(lst1, lst2):
@@ -179,12 +186,9 @@ class PublicTransportData():
         # Limit search to selected types, to get more results
         all_types = not bool(self.departure_type)
-        use_train = all_types or bool(
-            intersection(TRAIN_TYPES, self.departure_type))
-        use_bus = all_types or bool(
-            intersection(BUS_TYPES, self.departure_type))
-        use_metro = all_types or bool(
-            intersection(METRO_TYPES, self.departure_type))
+        use_train = all_types or bool(intersection(TRAIN_TYPES, self.departure_type))
+        use_bus = all_types or bool(intersection(BUS_TYPES, self.departure_type))
+        use_metro = all_types or bool(intersection(METRO_TYPES, self.departure_type))

         try:
             results = rjpl.departureBoard(
@@ -192,7 +196,7 @@ class PublicTransportData():
                 timeout=5,
                 useTrain=use_train,
                 useBus=use_bus,
-                useMetro=use_metro
+                useMetro=use_metro,
             )
         except rjpl.rjplAPIError as error:
             _LOGGER.debug("API returned error: %s", error)
@@ -204,36 +208,40 @@ class PublicTransportData():
             return

         # Filter result
-        results = [d for d in results if 'cancelled' not in d]
+        results = [d for d in results if "cancelled" not in d]
         if self.route:
-            results = [d for d in results if d['name'] in self.route]
+            results = [d for d in results if d["name"] in self.route]
         if self.direction:
-            results = [d for d in results if d['direction'] in self.direction]
+            results = [d for d in results if d["direction"] in self.direction]
         if self.departure_type:
-            results = [d for d in results if d['type'] in self.departure_type]
+            results = [d for d in results if d["type"] in self.departure_type]

         for item in results:
-            route = item.get('name')
-            due_at_date = item.get('rtDate')
-            due_at_time = item.get('rtTime')
+            route = item.get("name")
+            due_at_date = item.get("rtDate")
+            due_at_time = item.get("rtTime")

             if due_at_date is None:
-                due_at_date = item.get('date')  # Scheduled date
+                due_at_date = item.get("date")  # Scheduled date
             if due_at_time is None:
-                due_at_time = item.get('time')  # Scheduled time
+                due_at_time = item.get("time")  # Scheduled time

-            if (due_at_date is not None and
-                    due_at_time is not None and
-                    route is not None):
-                due_at = '{} {}'.format(due_at_date, due_at_time)
+            if (
+                due_at_date is not None
+                and due_at_time is not None
+                and route is not None
+            ):
+                due_at = "{} {}".format(due_at_date, due_at_time)

-                departure_data = {ATTR_DUE_IN: due_in_minutes(due_at),
-                                  ATTR_DUE_AT: due_at,
-                                  ATTR_TYPE: item.get('type'),
-                                  ATTR_ROUTE: route,
-                                  ATTR_DIRECTION: item.get('direction'),
-                                  ATTR_STOP_NAME: item.get('stop')}
+                departure_data = {
+                    ATTR_DUE_IN: due_in_minutes(due_at),
+                    ATTR_DUE_AT: due_at,
+                    ATTR_TYPE: item.get("type"),
+                    ATTR_ROUTE: route,
+                    ATTR_DIRECTION: item.get("direction"),
+                    ATTR_STOP_NAME: item.get("stop"),
+                }
                 self.info.append(departure_data)

         if not self.info:


@@ -14,11 +14,19 @@ from random import uniform
 from homeassistant.core import callback
 from homeassistant.helpers.dispatcher import async_dispatcher_send

 from ..helpers import (
-    configure_reporting, construct_unique_id,
-    safe_read, get_attr_id_by_name, bind_cluster, LogMixin)
+    configure_reporting,
+    construct_unique_id,
+    safe_read,
+    get_attr_id_by_name,
+    bind_cluster,
+    LogMixin,
+)
 from ..const import (
-    REPORT_CONFIG_DEFAULT, SIGNAL_ATTR_UPDATED, ATTRIBUTE_CHANNEL,
-    EVENT_RELAY_CHANNEL, ZDO_CHANNEL
+    REPORT_CONFIG_DEFAULT,
+    SIGNAL_ATTR_UPDATED,
+    ATTRIBUTE_CHANNEL,
+    EVENT_RELAY_CHANNEL,
+    ZDO_CHANNEL,
 )
 from ..registries import CLUSTER_REPORT_CONFIGS
@ -33,32 +41,33 @@ def parse_and_log_command(channel, tsn, command_id, args):
cmd, cmd,
args, args,
channel.cluster.cluster_id, channel.cluster.cluster_id,
tsn tsn,
) )
return cmd return cmd
def decorate_command(channel, command): def decorate_command(channel, command):
"""Wrap a cluster command to make it safe.""" """Wrap a cluster command to make it safe."""
@wraps(command) @wraps(command)
async def wrapper(*args, **kwds): async def wrapper(*args, **kwds):
from zigpy.exceptions import DeliveryError from zigpy.exceptions import DeliveryError
try: try:
result = await command(*args, **kwds) result = await command(*args, **kwds)
channel.debug("executed command: %s %s %s %s", channel.debug(
command.__name__, "executed command: %s %s %s %s",
"{}: {}".format("with args", args), command.__name__,
"{}: {}".format("with kwargs", kwds), "{}: {}".format("with args", args),
"{}: {}".format("and result", result)) "{}: {}".format("with kwargs", kwds),
"{}: {}".format("and result", result),
)
return result return result
except (DeliveryError, Timeout) as ex: except (DeliveryError, Timeout) as ex:
channel.debug( channel.debug("command failed: %s exception: %s", command.__name__, str(ex))
"command failed: %s exception: %s",
command.__name__,
str(ex)
)
return ex return ex
return wrapper return wrapper
@ -80,13 +89,12 @@ class ZigbeeChannel(LogMixin):
self._channel_name = cluster.ep_attribute self._channel_name = cluster.ep_attribute
if self.CHANNEL_NAME: if self.CHANNEL_NAME:
self._channel_name = self.CHANNEL_NAME self._channel_name = self.CHANNEL_NAME
self._generic_id = 'channel_0x{:04x}'.format(cluster.cluster_id) self._generic_id = "channel_0x{:04x}".format(cluster.cluster_id)
self._cluster = cluster self._cluster = cluster
self._zha_device = device self._zha_device = device
self._unique_id = construct_unique_id(cluster) self._unique_id = construct_unique_id(cluster)
self._report_config = CLUSTER_REPORT_CONFIGS.get( self._report_config = CLUSTER_REPORT_CONFIGS.get(
self._cluster.cluster_id, self._cluster.cluster_id, [{"attr": 0, "config": REPORT_CONFIG_DEFAULT}]
[{'attr': 0, 'config': REPORT_CONFIG_DEFAULT}]
) )
self._status = ChannelStatus.CREATED self._status = ChannelStatus.CREATED
self._cluster.add_listener(self) self._cluster.add_listener(self)
@ -130,21 +138,24 @@ class ZigbeeChannel(LogMixin):
manufacturer = None manufacturer = None
manufacturer_code = self._zha_device.manufacturer_code manufacturer_code = self._zha_device.manufacturer_code
# Xiaomi devices don't need this and it disrupts pairing # Xiaomi devices don't need this and it disrupts pairing
if self._zha_device.manufacturer != 'LUMI': if self._zha_device.manufacturer != "LUMI":
if self.cluster.cluster_id >= 0xfc00 and manufacturer_code: if self.cluster.cluster_id >= 0xFC00 and manufacturer_code:
manufacturer = manufacturer_code manufacturer = manufacturer_code
await bind_cluster(self._unique_id, self.cluster) await bind_cluster(self._unique_id, self.cluster)
if not self.cluster.bind_only: if not self.cluster.bind_only:
for report_config in self._report_config: for report_config in self._report_config:
attr = report_config.get('attr') attr = report_config.get("attr")
min_report_interval, max_report_interval, change = \ min_report_interval, max_report_interval, change = report_config.get(
report_config.get('config') "config"
)
await configure_reporting( await configure_reporting(
self._unique_id, self.cluster, attr, self._unique_id,
self.cluster,
attr,
min_report=min_report_interval, min_report=min_report_interval,
max_report=max_report_interval, max_report=max_report_interval,
reportable_change=change, reportable_change=change,
manufacturer=manufacturer manufacturer=manufacturer,
) )
await asyncio.sleep(uniform(0.1, 0.5)) await asyncio.sleep(uniform(0.1, 0.5))
@ -153,7 +164,7 @@ class ZigbeeChannel(LogMixin):
async def async_initialize(self, from_cache): async def async_initialize(self, from_cache):
"""Initialize channel.""" """Initialize channel."""
self.debug('initializing channel: from_cache: %s', from_cache) self.debug("initializing channel: from_cache: %s", from_cache)
self._status = ChannelStatus.INITIALIZED self._status = ChannelStatus.INITIALIZED
@callback @callback
@ -175,13 +186,13 @@ class ZigbeeChannel(LogMixin):
def zha_send_event(self, cluster, command, args): def zha_send_event(self, cluster, command, args):
"""Relay events to hass.""" """Relay events to hass."""
self._zha_device.hass.bus.async_fire( self._zha_device.hass.bus.async_fire(
'zha_event', "zha_event",
{ {
'unique_id': self._unique_id, "unique_id": self._unique_id,
'device_ieee': str(self._zha_device.ieee), "device_ieee": str(self._zha_device.ieee),
'command': command, "command": command,
'args': args "args": args,
} },
) )
async def async_update(self): async def async_update(self):
@ -192,14 +203,14 @@ class ZigbeeChannel(LogMixin):
"""Get the value for an attribute.""" """Get the value for an attribute."""
manufacturer = None manufacturer = None
manufacturer_code = self._zha_device.manufacturer_code manufacturer_code = self._zha_device.manufacturer_code
if self.cluster.cluster_id >= 0xfc00 and manufacturer_code: if self.cluster.cluster_id >= 0xFC00 and manufacturer_code:
manufacturer = manufacturer_code manufacturer = manufacturer_code
result = await safe_read( result = await safe_read(
self._cluster, self._cluster,
[attribute], [attribute],
allow_cache=from_cache, allow_cache=from_cache,
only_cache=from_cache, only_cache=from_cache,
manufacturer=manufacturer manufacturer=manufacturer,
) )
return result.get(attribute) return result.get(attribute)
@ -211,14 +222,10 @@ class ZigbeeChannel(LogMixin):
def __getattr__(self, name): def __getattr__(self, name):
"""Get attribute or a decorated cluster command.""" """Get attribute or a decorated cluster command."""
if hasattr(self._cluster, name) and callable( if hasattr(self._cluster, name) and callable(getattr(self._cluster, name)):
getattr(self._cluster, name)):
command = getattr(self._cluster, name) command = getattr(self._cluster, name)
command.__name__ = name command.__name__ = name
return decorate_command( return decorate_command(self, command)
self,
command
)
return self.__getattribute__(name) return self.__getattribute__(name)
@ -230,7 +237,7 @@ class AttributeListeningChannel(ZigbeeChannel):
def __init__(self, cluster, device): def __init__(self, cluster, device):
"""Initialize AttributeListeningChannel.""" """Initialize AttributeListeningChannel."""
super().__init__(cluster, device) super().__init__(cluster, device)
attr = self._report_config[0].get('attr') attr = self._report_config[0].get("attr")
if isinstance(attr, str): if isinstance(attr, str):
self.value_attribute = get_attr_id_by_name(self.cluster, attr) self.value_attribute = get_attr_id_by_name(self.cluster, attr)
else: else:
@ -243,13 +250,14 @@ class AttributeListeningChannel(ZigbeeChannel):
async_dispatcher_send( async_dispatcher_send(
self._zha_device.hass, self._zha_device.hass,
"{}_{}".format(self.unique_id, SIGNAL_ATTR_UPDATED), "{}_{}".format(self.unique_id, SIGNAL_ATTR_UPDATED),
value value,
) )
async def async_initialize(self, from_cache): async def async_initialize(self, from_cache):
"""Initialize listener.""" """Initialize listener."""
await self.get_attribute_value( await self.get_attribute_value(
self._report_config[0].get('attr'), from_cache=from_cache) self._report_config[0].get("attr"), from_cache=from_cache
)
await super().async_initialize(from_cache) await super().async_initialize(from_cache)
@ -293,7 +301,8 @@ class ZDOChannel(LogMixin):
async def async_initialize(self, from_cache): async def async_initialize(self, from_cache):
"""Initialize channel.""" """Initialize channel."""
entry = self._zha_device.gateway.zha_storage.async_get_or_create( entry = self._zha_device.gateway.zha_storage.async_get_or_create(
self._zha_device) self._zha_device
)
self.debug("entry loaded from storage: %s", entry) self.debug("entry loaded from storage: %s", entry)
self._status = ChannelStatus.INITIALIZED self._status = ChannelStatus.INITIALIZED
@ -320,21 +329,19 @@ class EventRelayChannel(ZigbeeChannel):
self._cluster, self._cluster,
SIGNAL_ATTR_UPDATED, SIGNAL_ATTR_UPDATED,
{ {
'attribute_id': attrid, "attribute_id": attrid,
'attribute_name': self._cluster.attributes.get( "attribute_name": self._cluster.attributes.get(attrid, ["Unknown"])[0],
attrid, "value": value,
['Unknown'])[0], },
'value': value
}
) )
@callback @callback
def cluster_command(self, tsn, command_id, args): def cluster_command(self, tsn, command_id, args):
"""Handle a cluster command received on this cluster.""" """Handle a cluster command received on this cluster."""
if self._cluster.server_commands is not None and \ if (
self._cluster.server_commands.get(command_id) is not None: self._cluster.server_commands is not None
and self._cluster.server_commands.get(command_id) is not None
):
self.zha_send_event( self.zha_send_event(
self._cluster, self._cluster, self._cluster.server_commands.get(command_id)[0], args
self._cluster.server_commands.get(command_id)[0],
args
) )


@@ -12,18 +12,46 @@ import time
 from homeassistant.core import callback
 from homeassistant.helpers.dispatcher import (
-    async_dispatcher_connect, async_dispatcher_send)
+    async_dispatcher_connect,
+    async_dispatcher_send,
+)
 from homeassistant.helpers.event import async_track_time_interval

 from .channels import EventRelayChannel
 from .const import (
-    ATTR_ARGS, ATTR_ATTRIBUTE, ATTR_CLUSTER_ID, ATTR_COMMAND,
-    ATTR_COMMAND_TYPE, ATTR_ENDPOINT_ID, ATTR_MANUFACTURER, ATTR_VALUE,
-    BATTERY_OR_UNKNOWN, CLIENT_COMMANDS, IEEE, IN, MAINS_POWERED,
-    MANUFACTURER_CODE, MODEL, NAME, NWK, OUT, POWER_CONFIGURATION_CHANNEL,
-    POWER_SOURCE, QUIRK_APPLIED, QUIRK_CLASS, SERVER, SERVER_COMMANDS,
-    SIGNAL_AVAILABLE, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZDO_CHANNEL,
-    LQI, RSSI, LAST_SEEN, ATTR_AVAILABLE)
+    ATTR_ARGS,
+    ATTR_ATTRIBUTE,
+    ATTR_CLUSTER_ID,
+    ATTR_COMMAND,
+    ATTR_COMMAND_TYPE,
+    ATTR_ENDPOINT_ID,
+    ATTR_MANUFACTURER,
+    ATTR_VALUE,
+    BATTERY_OR_UNKNOWN,
+    CLIENT_COMMANDS,
+    IEEE,
+    IN,
+    MAINS_POWERED,
+    MANUFACTURER_CODE,
+    MODEL,
+    NAME,
+    NWK,
+    OUT,
+    POWER_CONFIGURATION_CHANNEL,
+    POWER_SOURCE,
+    QUIRK_APPLIED,
+    QUIRK_CLASS,
+    SERVER,
+    SERVER_COMMANDS,
+    SIGNAL_AVAILABLE,
+    UNKNOWN_MANUFACTURER,
+    UNKNOWN_MODEL,
+    ZDO_CHANNEL,
+    LQI,
+    RSSI,
+    LAST_SEEN,
+    ATTR_AVAILABLE,
+)
 from .helpers import LogMixin

 _LOGGER = logging.getLogger(__name__)
@ -51,22 +79,20 @@ class ZHADevice(LogMixin):
self._all_channels = [] self._all_channels = []
self._available = False self._available = False
self._available_signal = "{}_{}_{}".format( self._available_signal = "{}_{}_{}".format(
self.name, self.ieee, SIGNAL_AVAILABLE) self.name, self.ieee, SIGNAL_AVAILABLE
)
self._unsub = async_dispatcher_connect( self._unsub = async_dispatcher_connect(
self.hass, self.hass, self._available_signal, self.async_initialize
self._available_signal,
self.async_initialize
) )
from zigpy.quirks import CustomDevice from zigpy.quirks import CustomDevice
self.quirk_applied = isinstance(self._zigpy_device, CustomDevice) self.quirk_applied = isinstance(self._zigpy_device, CustomDevice)
self.quirk_class = "{}.{}".format( self.quirk_class = "{}.{}".format(
self._zigpy_device.__class__.__module__, self._zigpy_device.__class__.__module__,
self._zigpy_device.__class__.__name__ self._zigpy_device.__class__.__name__,
) )
self._available_check = async_track_time_interval( self._available_check = async_track_time_interval(
self.hass, self.hass, self._check_available, _UPDATE_ALIVE_INTERVAL
self._check_available,
_UPDATE_ALIVE_INTERVAL
) )
self.status = DeviceStatus.CREATED self.status = DeviceStatus.CREATED
@ -184,15 +210,9 @@ class ZHADevice(LogMixin):
"""Set sensor availability.""" """Set sensor availability."""
if self._available != available and available: if self._available != available and available:
# Update the state the first time the device comes online # Update the state the first time the device comes online
async_dispatcher_send( async_dispatcher_send(self.hass, self._available_signal, False)
self.hass,
self._available_signal,
False
)
async_dispatcher_send( async_dispatcher_send(
self.hass, self.hass, "{}_{}".format(self._available_signal, "entity"), available
"{}_{}".format(self._available_signal, 'entity'),
available
) )
self._available = available self._available = available
@ -215,14 +235,16 @@ class ZHADevice(LogMixin):
LQI: self.lqi, LQI: self.lqi,
RSSI: self.rssi, RSSI: self.rssi,
LAST_SEEN: update_time, LAST_SEEN: update_time,
ATTR_AVAILABLE: self.available ATTR_AVAILABLE: self.available,
} }
def add_cluster_channel(self, cluster_channel): def add_cluster_channel(self, cluster_channel):
"""Add cluster channel to device.""" """Add cluster channel to device."""
# only keep 1 power configuration channel # only keep 1 power configuration channel
if cluster_channel.name is POWER_CONFIGURATION_CHANNEL and \ if (
POWER_CONFIGURATION_CHANNEL in self.cluster_channels: cluster_channel.name is POWER_CONFIGURATION_CHANNEL
and POWER_CONFIGURATION_CHANNEL in self.cluster_channels
):
return return
if isinstance(cluster_channel, EventRelayChannel): if isinstance(cluster_channel, EventRelayChannel):
@ -249,10 +271,9 @@ class ZHADevice(LogMixin):
def get_key(channel): def get_key(channel):
channel_key = "ZDO" channel_key = "ZDO"
if hasattr(channel.cluster, 'cluster_id'): if hasattr(channel.cluster, "cluster_id"):
channel_key = "{}_{}".format( channel_key = "{}_{}".format(
channel.cluster.endpoint.endpoint_id, channel.cluster.endpoint.endpoint_id, channel.cluster.cluster_id
channel.cluster.cluster_id
) )
return channel_key return channel_key
@ -273,21 +294,23 @@ class ZHADevice(LogMixin):
async def async_configure(self): async def async_configure(self):
"""Configure the device.""" """Configure the device."""
self.debug('started configuration') self.debug("started configuration")
await self._execute_channel_tasks( await self._execute_channel_tasks(
self.get_channels_to_configure(), 'async_configure') self.get_channels_to_configure(), "async_configure"
self.debug('completed configuration') )
self.debug("completed configuration")
entry = self.gateway.zha_storage.async_create_or_update(self) entry = self.gateway.zha_storage.async_create_or_update(self)
self.debug('stored in registry: %s', entry) self.debug("stored in registry: %s", entry)
async def async_initialize(self, from_cache=False): async def async_initialize(self, from_cache=False):
"""Initialize channels.""" """Initialize channels."""
self.debug('started initialization') self.debug("started initialization")
await self._execute_channel_tasks( await self._execute_channel_tasks(
self.all_channels, 'async_initialize', from_cache) self.all_channels, "async_initialize", from_cache
self.debug('power source: %s', self.power_source) )
self.debug("power source: %s", self.power_source)
self.status = DeviceStatus.INITIALIZED self.status = DeviceStatus.INITIALIZED
self.debug('completed initialization') self.debug("completed initialization")
async def _execute_channel_tasks(self, channels, task_name, *args): async def _execute_channel_tasks(self, channels, task_name, *args):
"""Gather and execute a set of CHANNEL tasks.""" """Gather and execute a set of CHANNEL tasks."""
@ -299,11 +322,12 @@ class ZHADevice(LogMixin):
# pylint: disable=E1111 # pylint: disable=E1111
if zdo_task is None: # We only want to do this once if zdo_task is None: # We only want to do this once
zdo_task = self._async_create_task( zdo_task = self._async_create_task(
semaphore, channel, task_name, *args) semaphore, channel, task_name, *args
)
else: else:
channel_tasks.append( channel_tasks.append(
self._async_create_task( self._async_create_task(semaphore, channel, task_name, *args)
semaphore, channel, task_name, *args)) )
if zdo_task is not None: if zdo_task is not None:
await zdo_task await zdo_task
await asyncio.gather(*channel_tasks) await asyncio.gather(*channel_tasks)
@ -332,10 +356,8 @@ class ZHADevice(LogMixin):
def async_get_clusters(self): def async_get_clusters(self):
"""Get all clusters for this device.""" """Get all clusters for this device."""
return { return {
ep_id: { ep_id: {IN: endpoint.in_clusters, OUT: endpoint.out_clusters}
IN: endpoint.in_clusters, for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
OUT: endpoint.out_clusters
} for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0 if ep_id != 0
} }
@ -343,15 +365,11 @@ class ZHADevice(LogMixin):
def async_get_std_clusters(self): def async_get_std_clusters(self):
"""Get ZHA and ZLL clusters for this device.""" """Get ZHA and ZLL clusters for this device."""
from zigpy.profiles import zha, zll from zigpy.profiles import zha, zll
return { return {
ep_id: { ep_id: {IN: endpoint.in_clusters, OUT: endpoint.out_clusters}
IN: endpoint.in_clusters, for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
OUT: endpoint.out_clusters if ep_id != 0 and endpoint.profile_id in (zha.PROFILE_ID, zll.PROFILE_ID)
} for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0 and endpoint.profile_id in (
zha.PROFILE_ID,
zll.PROFILE_ID
)
} }
@callback @callback
@ -361,18 +379,15 @@ class ZHADevice(LogMixin):
return clusters[endpoint_id][cluster_type][cluster_id] return clusters[endpoint_id][cluster_type][cluster_id]
@callback @callback
def async_get_cluster_attributes(self, endpoint_id, cluster_id, def async_get_cluster_attributes(self, endpoint_id, cluster_id, cluster_type=IN):
cluster_type=IN):
"""Get zigbee attributes for specified cluster.""" """Get zigbee attributes for specified cluster."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
cluster_type)
if cluster is None: if cluster is None:
return None return None
return cluster.attributes return cluster.attributes
@callback @callback
def async_get_cluster_commands(self, endpoint_id, cluster_id, def async_get_cluster_commands(self, endpoint_id, cluster_id, cluster_type=IN):
cluster_type=IN):
"""Get zigbee commands for specified cluster.""" """Get zigbee commands for specified cluster."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type) cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None: if cluster is None:
@ -382,64 +397,77 @@ class ZHADevice(LogMixin):
SERVER_COMMANDS: cluster.server_commands, SERVER_COMMANDS: cluster.server_commands,
} }
async def write_zigbee_attribute(self, endpoint_id, cluster_id, async def write_zigbee_attribute(
attribute, value, cluster_type=IN, self,
manufacturer=None): endpoint_id,
cluster_id,
attribute,
value,
cluster_type=IN,
manufacturer=None,
):
"""Write a value to a zigbee attribute for a cluster in this entity.""" """Write a value to a zigbee attribute for a cluster in this entity."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type) cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None: if cluster is None:
return None return None
from zigpy.exceptions import DeliveryError from zigpy.exceptions import DeliveryError
try: try:
response = await cluster.write_attributes( response = await cluster.write_attributes(
{attribute: value}, {attribute: value}, manufacturer=manufacturer
manufacturer=manufacturer
) )
self.debug( self.debug(
'set: %s for attr: %s to cluster: %s for ept: %s - res: %s', "set: %s for attr: %s to cluster: %s for ept: %s - res: %s",
value, value,
attribute, attribute,
cluster_id, cluster_id,
endpoint_id, endpoint_id,
response response,
) )
return response return response
except DeliveryError as exc: except DeliveryError as exc:
self.debug( self.debug(
'failed to set attribute: %s %s %s %s %s', "failed to set attribute: %s %s %s %s %s",
'{}: {}'.format(ATTR_VALUE, value), "{}: {}".format(ATTR_VALUE, value),
'{}: {}'.format(ATTR_ATTRIBUTE, attribute), "{}: {}".format(ATTR_ATTRIBUTE, attribute),
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_id), "{}: {}".format(ATTR_CLUSTER_ID, cluster_id),
'{}: {}'.format(ATTR_ENDPOINT_ID, endpoint_id), "{}: {}".format(ATTR_ENDPOINT_ID, endpoint_id),
exc exc,
) )
return None return None
async def issue_cluster_command(self, endpoint_id, cluster_id, command, async def issue_cluster_command(
command_type, args, cluster_type=IN, self,
manufacturer=None): endpoint_id,
cluster_id,
command,
command_type,
args,
cluster_type=IN,
manufacturer=None,
):
"""Issue a command against specified zigbee cluster on this entity.""" """Issue a command against specified zigbee cluster on this entity."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type) cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None: if cluster is None:
return None return None
response = None response = None
if command_type == SERVER: if command_type == SERVER:
response = await cluster.command(command, *args, response = await cluster.command(
manufacturer=manufacturer, command, *args, manufacturer=manufacturer, expect_reply=True
expect_reply=True) )
else: else:
response = await cluster.client_command(command, *args) response = await cluster.client_command(command, *args)
self.debug( self.debug(
'Issued cluster command: %s %s %s %s %s %s %s', "Issued cluster command: %s %s %s %s %s %s %s",
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_id), "{}: {}".format(ATTR_CLUSTER_ID, cluster_id),
'{}: {}'.format(ATTR_COMMAND, command), "{}: {}".format(ATTR_COMMAND, command),
'{}: {}'.format(ATTR_COMMAND_TYPE, command_type), "{}: {}".format(ATTR_COMMAND_TYPE, command_type),
'{}: {}'.format(ATTR_ARGS, args), "{}: {}".format(ATTR_ARGS, args),
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_type), "{}: {}".format(ATTR_CLUSTER_ID, cluster_type),
'{}: {}'.format(ATTR_MANUFACTURER, manufacturer), "{}: {}".format(ATTR_MANUFACTURER, manufacturer),
'{}: {}'.format(ATTR_ENDPOINT_ID, endpoint_id) "{}: {}".format(ATTR_ENDPOINT_ID, endpoint_id),
) )
return response return response


@@ -12,13 +12,19 @@ from homeassistant.helpers.restore_state import RestoreEntity
 from homeassistant.util import slugify

 from .core.const import (
-    ATTR_MANUFACTURER, DATA_ZHA, DATA_ZHA_BRIDGE_ID, DOMAIN, MODEL, NAME,
-    SIGNAL_REMOVE)
+    ATTR_MANUFACTURER,
+    DATA_ZHA,
+    DATA_ZHA_BRIDGE_ID,
+    DOMAIN,
+    MODEL,
+    NAME,
+    SIGNAL_REMOVE,
+)
 from .core.helpers import LogMixin

 _LOGGER = logging.getLogger(__name__)

-ENTITY_SUFFIX = 'entity_suffix'
+ENTITY_SUFFIX = "entity_suffix"
 RESTART_GRACE_PERIOD = 7200  # 2 hours
@ -27,29 +33,28 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
_domain = None # Must be overridden by subclasses _domain = None # Must be overridden by subclasses
def __init__(self, unique_id, zha_device, channels, def __init__(self, unique_id, zha_device, channels, skip_entity_id=False, **kwargs):
skip_entity_id=False, **kwargs):
"""Init ZHA entity.""" """Init ZHA entity."""
self._force_update = False self._force_update = False
self._should_poll = False self._should_poll = False
self._unique_id = unique_id self._unique_id = unique_id
if not skip_entity_id: if not skip_entity_id:
ieee = zha_device.ieee ieee = zha_device.ieee
ieeetail = ''.join(['%02x' % (o, ) for o in ieee[-4:]]) ieeetail = "".join(["%02x" % (o,) for o in ieee[-4:]])
self.entity_id = "{}.{}_{}_{}_{}{}".format( self.entity_id = "{}.{}_{}_{}_{}{}".format(
self._domain, self._domain,
slugify(zha_device.manufacturer), slugify(zha_device.manufacturer),
slugify(zha_device.model), slugify(zha_device.model),
ieeetail, ieeetail,
channels[0].cluster.endpoint.endpoint_id, channels[0].cluster.endpoint.endpoint_id,
kwargs.get(ENTITY_SUFFIX, ''), kwargs.get(ENTITY_SUFFIX, ""),
) )
self._state = None self._state = None
self._device_state_attributes = {} self._device_state_attributes = {}
self._zha_device = zha_device self._zha_device = zha_device
self.cluster_channels = {} self.cluster_channels = {}
self._available = False self._available = False
self._component = kwargs['component'] self._component = kwargs["component"]
self._unsubs = [] self._unsubs = []
self.remove_future = None self.remove_future = None
for channel in channels: for channel in channels:
@ -89,15 +94,14 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
def device_info(self): def device_info(self):
"""Return a device description for device registry.""" """Return a device description for device registry."""
zha_device_info = self._zha_device.device_info zha_device_info = self._zha_device.device_info
ieee = zha_device_info['ieee'] ieee = zha_device_info["ieee"]
return { return {
'connections': {(CONNECTION_ZIGBEE, ieee)}, "connections": {(CONNECTION_ZIGBEE, ieee)},
'identifiers': {(DOMAIN, ieee)}, "identifiers": {(DOMAIN, ieee)},
ATTR_MANUFACTURER: zha_device_info[ATTR_MANUFACTURER], ATTR_MANUFACTURER: zha_device_info[ATTR_MANUFACTURER],
MODEL: zha_device_info[MODEL], MODEL: zha_device_info[MODEL],
NAME: zha_device_info[NAME], NAME: zha_device_info[NAME],
'via_device': ( "via_device": (DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]),
DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]),
} }
@property @property
@ -112,9 +116,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
def async_update_state_attribute(self, key, value): def async_update_state_attribute(self, key, value):
"""Update a single device state attribute.""" """Update a single device state attribute."""
self._device_state_attributes.update({ self._device_state_attributes.update({key: value})
key: value
})
self.async_schedule_update_ha_state() self.async_schedule_update_ha_state()
def async_set_state(self, state): def async_set_state(self, state):
@ -127,24 +129,34 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
self.remove_future = asyncio.Future() self.remove_future = asyncio.Future()
await self.async_check_recently_seen() await self.async_check_recently_seen()
await self.async_accept_signal( await self.async_accept_signal(
None, "{}_{}".format(self.zha_device.available_signal, 'entity'), None,
"{}_{}".format(self.zha_device.available_signal, "entity"),
self.async_set_available, self.async_set_available,
signal_override=True) signal_override=True,
)
await self.async_accept_signal( await self.async_accept_signal(
None, "{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)), None,
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
self.async_remove, self.async_remove,
signal_override=True signal_override=True,
) )
self._zha_device.gateway.register_entity_reference( self._zha_device.gateway.register_entity_reference(
self._zha_device.ieee, self.entity_id, self._zha_device, self._zha_device.ieee,
self.cluster_channels, self.device_info, self.remove_future) self.entity_id,
self._zha_device,
self.cluster_channels,
self.device_info,
self.remove_future,
)
async def async_check_recently_seen(self): async def async_check_recently_seen(self):
"""Check if the device was seen within the last 2 hours.""" """Check if the device was seen within the last 2 hours."""
last_state = await self.async_get_last_state() last_state = await self.async_get_last_state()
if last_state and self._zha_device.last_seen and ( if (
time.time() - self._zha_device.last_seen < last_state
RESTART_GRACE_PERIOD): and self._zha_device.last_seen
and (time.time() - self._zha_device.last_seen < RESTART_GRACE_PERIOD)
):
self.async_set_available(True) self.async_set_available(True)
if not self.zha_device.is_mains_powered: if not self.zha_device.is_mains_powered:
# mains powered devices will get real time state # mains powered devices will get real time state
@ -167,24 +179,17 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
async def async_update(self): async def async_update(self):
"""Retrieve latest state.""" """Retrieve latest state."""
for channel in self.cluster_channels.values(): for channel in self.cluster_channels.values():
if hasattr(channel, 'async_update'): if hasattr(channel, "async_update"):
await channel.async_update() await channel.async_update()
async def async_accept_signal(self, channel, signal, func, async def async_accept_signal(self, channel, signal, func, signal_override=False):
signal_override=False):
"""Accept a signal from a channel.""" """Accept a signal from a channel."""
unsub = None unsub = None
if signal_override: if signal_override:
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(self.hass, signal, func)
self.hass,
signal,
func
)
else: else:
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
self.hass, self.hass, "{}_{}".format(channel.unique_id, signal), func
"{}_{}".format(channel.unique_id, signal),
func
) )
self._unsubs.append(unsub) self._unsubs.append(unsub)


@@ -9,21 +9,29 @@ import random
 import string
 from functools import wraps
 from types import MappingProxyType
-from typing import (Any, Optional, TypeVar, Callable, KeysView, Union,  # noqa
-                    Iterable, List, Dict, Iterator, Coroutine, MutableSet)
+from typing import (
+    Any,
+    Optional,
+    TypeVar,
+    Callable,
+    KeysView,
+    Union,  # noqa
+    Iterable,
+    Coroutine,
+)

 import slugify as unicode_slug

 from .dt import as_local, utcnow

 # pylint: disable=invalid-name
-T = TypeVar('T')
-U = TypeVar('U')
-ENUM_T = TypeVar('ENUM_T', bound=enum.Enum)
+T = TypeVar("T")
+U = TypeVar("U")
+ENUM_T = TypeVar("ENUM_T", bound=enum.Enum)
 # pylint: enable=invalid-name

-RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
-RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
+RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
+RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")


 def sanitize_filename(filename: str) -> str:
@ -38,23 +46,24 @@ def sanitize_path(path: str) -> str:
def slugify(text: str) -> str: def slugify(text: str) -> str:
"""Slugify a given text.""" """Slugify a given text."""
return unicode_slug.slugify(text, separator='_') # type: ignore return unicode_slug.slugify(text, separator="_") # type: ignore
def repr_helper(inp: Any) -> str: def repr_helper(inp: Any) -> str:
"""Help creating a more readable string representation of objects.""" """Help creating a more readable string representation of objects."""
if isinstance(inp, (dict, MappingProxyType)): if isinstance(inp, (dict, MappingProxyType)):
return ", ".join( return ", ".join(
repr_helper(key)+"="+repr_helper(item) for key, item repr_helper(key) + "=" + repr_helper(item) for key, item in inp.items()
in inp.items()) )
if isinstance(inp, datetime): if isinstance(inp, datetime):
return as_local(inp).isoformat() return as_local(inp).isoformat()
return str(inp) return str(inp)
def convert(value: Optional[T], to_type: Callable[[T], U], def convert(
default: Optional[U] = None) -> Optional[U]: value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None
) -> Optional[U]:
"""Convert value to to_type, returns default if fails.""" """Convert value to to_type, returns default if fails."""
try: try:
return default if value is None else to_type(value) return default if value is None else to_type(value)
@ -63,8 +72,9 @@ def convert(value: Optional[T], to_type: Callable[[T], U],
return default return default
def ensure_unique_string(preferred_string: str, current_strings: def ensure_unique_string(
Union[Iterable[str], KeysView[str]]) -> str: preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]]
) -> str:
"""Return a string that is not present in current_strings. """Return a string that is not present in current_strings.
If preferred string exists will append _2, _3, .. If preferred string exists will append _2, _3, ..
@ -88,14 +98,14 @@ def get_local_ip() -> str:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Use Google Public DNS server to determine own IP # Use Google Public DNS server to determine own IP
sock.connect(('8.8.8.8', 80)) sock.connect(("8.8.8.8", 80))
return sock.getsockname()[0] # type: ignore return sock.getsockname()[0] # type: ignore
except socket.error: except socket.error:
try: try:
return socket.gethostbyname(socket.gethostname()) return socket.gethostbyname(socket.gethostname())
except socket.gaierror: except socket.gaierror:
return '127.0.0.1' return "127.0.0.1"
finally: finally:
sock.close() sock.close()
@ -106,7 +116,7 @@ def get_random_string(length: int = 10) -> str:
generator = random.SystemRandom() generator = random.SystemRandom()
source_chars = string.ascii_letters + string.digits source_chars = string.ascii_letters + string.digits
return ''.join(generator.choice(source_chars) for _ in range(length)) return "".join(generator.choice(source_chars) for _ in range(length))
class OrderedEnum(enum.Enum): class OrderedEnum(enum.Enum):
@ -158,8 +168,9 @@ class Throttle:
Adds a datetime attribute `last_call` to the method. Adds a datetime attribute `last_call` to the method.
""" """
def __init__(self, min_time: timedelta, def __init__(
limit_no_throttle: Optional[timedelta] = None) -> None: self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None
) -> None:
"""Initialize the throttle.""" """Initialize the throttle."""
self.min_time = min_time self.min_time = min_time
self.limit_no_throttle = limit_no_throttle self.limit_no_throttle = limit_no_throttle
@ -168,10 +179,13 @@ class Throttle:
"""Caller for the throttle.""" """Caller for the throttle."""
# Make sure we return a coroutine if the method is async. # Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method): if asyncio.iscoroutinefunction(method):
async def throttled_value() -> None: async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled.""" """Stand-in function for when real func is being throttled."""
return None return None
else: else:
def throttled_value() -> None: # type: ignore def throttled_value() -> None: # type: ignore
"""Stand-in function for when real func is being throttled.""" """Stand-in function for when real func is being throttled."""
return None return None
@ -189,8 +203,10 @@ class Throttle:
# All methods have the classname in their qualname separated by a '.' # All methods have the classname in their qualname separated by a '.'
# Functions have a '.' in their qualname if defined inline, but will # Functions have a '.' in their qualname if defined inline, but will
# be prefixed by '.<locals>.' so we strip that out. # be prefixed by '.<locals>.' so we strip that out.
is_func = (not hasattr(method, '__self__') and is_func = (
'.' not in method.__qualname__.split('.<locals>.')[-1]) not hasattr(method, "__self__")
and "." not in method.__qualname__.split(".<locals>.")[-1]
)
@wraps(method) @wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]: def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
@ -199,14 +215,14 @@ class Throttle:
If we cannot acquire the lock, it is running so return None. If we cannot acquire the lock, it is running so return None.
""" """
# pylint: disable=protected-access # pylint: disable=protected-access
if hasattr(method, '__self__'): if hasattr(method, "__self__"):
host = getattr(method, '__self__') host = getattr(method, "__self__")
elif is_func: elif is_func:
host = wrapper host = wrapper
else: else:
host = args[0] if args else wrapper host = args[0] if args else wrapper
if not hasattr(host, '_throttle'): if not hasattr(host, "_throttle"):
host._throttle = {} host._throttle = {}
if id(self) not in host._throttle: if id(self) not in host._throttle:
@ -217,7 +233,7 @@ class Throttle:
return throttled_value() return throttled_value()
# Check if method is never called or no_throttle is given # Check if method is never called or no_throttle is given
force = kwargs.pop('no_throttle', False) or not throttle[1] force = kwargs.pop("no_throttle", False) or not throttle[1]
try: try:
if force or utcnow() - throttle[1] > self.min_time: if force or utcnow() - throttle[1] > self.min_time:


@@ -6,6 +6,7 @@ good-names=i,j,k,ex,Run,_,fp

 [MESSAGES CONTROL]
 # Reasons disabled:
+# format - handled by black
 # locally-disabled - it spams too much
 # duplicate-code - unavoidable
 # cyclic-import - doesn't test if both import on load
@@ -20,6 +21,7 @@ good-names=i,j,k,ex,Run,_,fp
 # not-an-iterable - https://github.com/PyCQA/pylint/issues/2311
 # unnecessary-pass - readability for functions which only contain pass
 disable=
+  format,
   abstract-class-little-used,
   abstract-method,
   cyclic-import,

pyproject.toml Normal file

@@ -0,0 +1,3 @@
+[tool.black]
+target-version = ["py36", "py37", "py38"]
+exclude = 'generated'
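Black picks up this [tool.black] table automatically when invoked from the repository root, so no extra flags are needed for the target versions or the generated-code exclusion; a minimal sketch of manual use:

    black homeassistant tests script     # rewrite files in place
    black --check --diff homeassistant   # only report what would change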


@@ -1,7 +1,10 @@
 # linters such as flake8 and pylint should be pinned, as new releases
 # make new things fail. Manually update these pins when pulling in a
 # new version
+# When updating this file, update .pre-commit-config.yaml too
 asynctest==0.13.0
+black==19.3b0
 codecov==2.0.15
 coveralls==1.2.0
 flake8-docstrings==1.3.0
@@ -16,3 +19,4 @@ pytest-sugar==0.9.2
 pytest-timeout==1.3.3
 pytest==5.0.1
 requests_mock==1.6.0
+pre-commit==1.17.0


@@ -2,7 +2,10 @@
 # linters such as flake8 and pylint should be pinned, as new releases
 # make new things fail. Manually update these pins when pulling in a
 # new version
+# When updating this file, update .pre-commit-config.yaml too
 asynctest==0.13.0
+black==19.3b0
 codecov==2.0.15
 coveralls==1.2.0
 flake8-docstrings==1.3.0
@@ -17,6 +20,7 @@ pytest-sugar==0.9.2
 pytest-timeout==1.3.3
 pytest==5.0.1
 requests_mock==1.6.0
+pre-commit==1.17.0

 # homeassistant.components.homekit


@@ -7,4 +7,4 @@ set -e
 cd "$(dirname "$0")/.."

 echo "Installing test dependencies..."
-python3 -m pip install tox colorlog
+python3 -m pip install tox colorlog pre-commit

script/check_format Executable file

@@ -0,0 +1,10 @@
+#!/bin/sh
+# Format code with black.
+
+cd "$(dirname "$0")/.."
+
+black \
+    --check \
+    --fast \
+    --quiet \
+    homeassistant tests script


@@ -7,4 +7,5 @@ set -e
 cd "$(dirname "$0")/.."

 script/bootstrap
+pre-commit install
 pip3 install -e .
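With pre-commit install wired into script/setup, the Black hook now runs on every local commit; a rough sketch of the resulting workflow (the file path is hypothetical):

    git add homeassistant/components/demo/sensor.py
    git commit -m "Tweak demo sensor"
    # if Black rewrites the file, the hook fails the commit;
    # re-stage the reformatted file and commit again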


@@ -21,12 +21,27 @@ norecursedirs = .git testing_config

 [flake8]
 exclude = .venv,.git,.tox,docs,venv,bin,lib,deps,build
+# To work with Black
+max-line-length = 88
+# E501: line too long
+# W503: Line break occurred before a binary operator
+# E203: Whitespace before ':'
+# D202 No blank lines allowed after function docstring
+ignore =
+    E501,
+    W503,
+    E203,
+    D202

 [isort]
 # https://github.com/timothycrosley/isort
 # https://github.com/timothycrosley/isort/wiki/isort-Settings
 # splits long import on multiple lines indented by 4 spaces
-multi_line_output = 4
+multi_line_output = 3
+include_trailing_comma=True
+force_grid_wrap=0
+use_parentheses=True
+line_length=88
 indent = "    "
 # by default isort don't check module indexes
 not_skip = __init__.py
@@ -37,4 +52,3 @@ default_section = THIRDPARTY
 known_first_party = homeassistant,tests
 forced_separate = tests
 combine_as_imports = true
-use_parentheses = true
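The flake8 and isort settings above simply defer formatting decisions to Black (88-character lines, trailing commas, parenthesised wrapping), so the three tools should not fight each other; a minimal sketch for verifying that locally, assuming flake8 and black are installed from requirements_test.txt:

    black --check homeassistant tests script   # formatting
    flake8 homeassistant tests script          # style checks Black does not cover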