Make it easier for auth to consume newer formats (#17127)

This commit is contained in:
Paulus Schoutsen 2018-10-04 10:41:13 +02:00 committed by GitHub
parent cc1891ef2b
commit a559c06d6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 18 deletions

View File

@ -1,9 +1,9 @@
"""Storage for auth models.""" """Storage for auth models."""
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta
import hmac
from logging import getLogger from logging import getLogger
from typing import Any, Dict, List, Optional # noqa: F401 from typing import Any, Dict, List, Optional # noqa: F401
import hmac
from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION from homeassistant.auth.const import ACCESS_TOKEN_EXPIRATION
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
@ -214,14 +214,24 @@ class AuthStore:
if self._users is not None: if self._users is not None:
return return
users = OrderedDict() # type: Dict[str, models.User]
if data is None: if data is None:
self._users = users self._set_defaults()
return return
users = OrderedDict() # type: Dict[str, models.User]
# When creating objects we mention each attribute explicetely. This
# prevents crashing if user rolls back HA version after a new property
# was added.
for user_dict in data['users']: for user_dict in data['users']:
users[user_dict['id']] = models.User(**user_dict) users[user_dict['id']] = models.User(
name=user_dict['name'],
id=user_dict['id'],
is_owner=user_dict['is_owner'],
is_active=user_dict['is_active'],
system_generated=user_dict['system_generated'],
)
for cred_dict in data['credentials']: for cred_dict in data['credentials']:
users[cred_dict['user_id']].credentials.append(models.Credentials( users[cred_dict['user_id']].credentials.append(models.Credentials(
@ -341,3 +351,7 @@ class AuthStore:
'credentials': credentials, 'credentials': credentials,
'refresh_tokens': refresh_tokens, 'refresh_tokens': refresh_tokens,
} }
def _set_defaults(self) -> None:
"""Set default values for auth store."""
self._users = OrderedDict() # type: Dict[str, models.User]

View File

@ -19,19 +19,19 @@ class User:
"""A user.""" """A user."""
name = attr.ib(type=str) # type: Optional[str] name = attr.ib(type=str) # type: Optional[str]
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
is_owner = attr.ib(type=bool, default=False) is_owner = attr.ib(type=bool, default=False)
is_active = attr.ib(type=bool, default=False) is_active = attr.ib(type=bool, default=False)
system_generated = attr.ib(type=bool, default=False) system_generated = attr.ib(type=bool, default=False)
# List of credentials of a user. # List of credentials of a user.
credentials = attr.ib( credentials = attr.ib(
type=list, default=attr.Factory(list), cmp=False type=list, factory=list, cmp=False
) # type: List[Credentials] ) # type: List[Credentials]
# Tokens associated with a user. # Tokens associated with a user.
refresh_tokens = attr.ib( refresh_tokens = attr.ib(
type=dict, default=attr.Factory(dict), cmp=False type=dict, factory=dict, cmp=False
) # type: Dict[str, RefreshToken] ) # type: Dict[str, RefreshToken]
@ -48,12 +48,10 @@ class RefreshToken:
validator=attr.validators.in_(( validator=attr.validators.in_((
TOKEN_TYPE_NORMAL, TOKEN_TYPE_SYSTEM, TOKEN_TYPE_NORMAL, TOKEN_TYPE_SYSTEM,
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN))) TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN)))
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
created_at = attr.ib(type=datetime, default=attr.Factory(dt_util.utcnow)) created_at = attr.ib(type=datetime, factory=dt_util.utcnow)
token = attr.ib(type=str, token = attr.ib(type=str, factory=lambda: generate_secret(64))
default=attr.Factory(lambda: generate_secret(64))) jwt_key = attr.ib(type=str, factory=lambda: generate_secret(64))
jwt_key = attr.ib(type=str,
default=attr.Factory(lambda: generate_secret(64)))
last_used_at = attr.ib(type=Optional[datetime], default=None) last_used_at = attr.ib(type=Optional[datetime], default=None)
last_used_ip = attr.ib(type=Optional[str], default=None) last_used_ip = attr.ib(type=Optional[str], default=None)
@ -69,7 +67,7 @@ class Credentials:
# Allow the auth provider to store data to represent their auth. # Allow the auth provider to store data to represent their auth.
data = attr.ib(type=dict) data = attr.ib(type=dict)
id = attr.ib(type=str, default=attr.Factory(lambda: uuid.uuid4().hex)) id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
is_new = attr.ib(type=bool, default=True) is_new = attr.ib(type=bool, default=True)

View File

@ -2,7 +2,7 @@
import asyncio import asyncio
import logging import logging
import os import os
from typing import Dict, Optional, Callable from typing import Dict, Optional, Callable, Any
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback from homeassistant.core import callback
@ -63,7 +63,7 @@ class Store:
"""Return the config path.""" """Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key) return self.hass.config.path(STORAGE_DIR, self.key)
async def async_load(self): async def async_load(self) -> Optional[Dict[str, Any]]:
"""Load data. """Load data.
If the expected version does not match the given version, the migrate If the expected version does not match the given version, the migrate

View File

@ -392,7 +392,7 @@ def ensure_auth_manager_loaded(auth_mgr):
"""Ensure an auth manager is considered loaded.""" """Ensure an auth manager is considered loaded."""
store = auth_mgr._store store = auth_mgr._store
if store._users is None: if store._users is None:
store._users = OrderedDict() store._set_defaults()
class MockModule: class MockModule: