mirror of
https://github.com/home-assistant/core.git
synced 2025-11-09 02:49:40 +00:00
Only create front-end client_id once (#15214)
* Only create frontend client_id once * Check user and client_id before create refresh token * Lint * Follow code review comment * Minor clenaup * Update doc string
This commit is contained in:
committed by
Paulus Schoutsen
parent
c3ad30ec87
commit
11ba7cc8ce
@@ -1,23 +1,22 @@
|
||||
"""Provide an authentication layer for Home Assistant."""
|
||||
import asyncio
|
||||
import binascii
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant import data_entry_flow, requirements
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import CONF_TYPE, CONF_NAME, CONF_ID
|
||||
from homeassistant.util.decorator import Registry
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from homeassistant.util.decorator import Registry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -337,6 +336,16 @@ class AuthManager:
|
||||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_or_create_client(self, name, *, redirect_uris=None,
|
||||
no_secret=False):
|
||||
"""Find a client, if not exists, create a new one."""
|
||||
for client in await self._store.async_get_clients():
|
||||
if client.name == name:
|
||||
return client
|
||||
|
||||
return await self._store.async_create_client(
|
||||
name, redirect_uris, no_secret)
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
return await self._store.async_get_client(client_id)
|
||||
@@ -380,29 +389,36 @@ class AuthStore:
|
||||
def __init__(self, hass):
|
||||
"""Initialize the auth store."""
|
||||
self.hass = hass
|
||||
self.users = None
|
||||
self.clients = None
|
||||
self._users = None
|
||||
self._clients = None
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
async def credentials_for_provider(self, provider_type, provider_id):
|
||||
"""Return credentials for specific auth provider type and id."""
|
||||
if self.users is None:
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return [
|
||||
credentials
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for credentials in user.credentials
|
||||
if (credentials.auth_provider_type == provider_type and
|
||||
credentials.auth_provider_id == provider_id)
|
||||
]
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
if self.users is None:
|
||||
async def async_get_users(self):
|
||||
"""Retrieve all users."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return self.users.get(user_id)
|
||||
return list(self._users.values())
|
||||
|
||||
async def async_get_user(self, user_id):
|
||||
"""Retrieve a user."""
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._users.get(user_id)
|
||||
|
||||
async def async_get_or_create_user(self, credentials, auth_provider):
|
||||
"""Get or create a new user for given credentials.
|
||||
@@ -410,7 +426,7 @@ class AuthStore:
|
||||
If link_user is passed in, the credentials will be linked to the passed
|
||||
in user if the credentials are new.
|
||||
"""
|
||||
if self.users is None:
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
# New credentials, store in user
|
||||
@@ -418,7 +434,7 @@ class AuthStore:
|
||||
info = await auth_provider.async_user_meta_for_credentials(
|
||||
credentials)
|
||||
# Make owner and activate user if it's the first user.
|
||||
if self.users:
|
||||
if self._users:
|
||||
is_owner = False
|
||||
is_active = False
|
||||
else:
|
||||
@@ -430,11 +446,11 @@ class AuthStore:
|
||||
is_active=is_active,
|
||||
name=info.get('name'),
|
||||
)
|
||||
self.users[new_user.id] = new_user
|
||||
self._users[new_user.id] = new_user
|
||||
await self.async_link_user(new_user, credentials)
|
||||
return new_user
|
||||
|
||||
for user in self.users.values():
|
||||
for user in self._users.values():
|
||||
for creds in user.credentials:
|
||||
if (creds.auth_provider_type == credentials.auth_provider_type
|
||||
and creds.auth_provider_id ==
|
||||
@@ -451,11 +467,19 @@ class AuthStore:
|
||||
|
||||
async def async_remove_user(self, user):
|
||||
"""Remove a user."""
|
||||
self.users.pop(user.id)
|
||||
self._users.pop(user.id)
|
||||
await self.async_save()
|
||||
|
||||
async def async_create_refresh_token(self, user, client_id):
|
||||
"""Create a new token for a user."""
|
||||
local_user = await self.async_get_user(user.id)
|
||||
if local_user is None:
|
||||
raise ValueError('Invalid user')
|
||||
|
||||
local_client = await self.async_get_client(client_id)
|
||||
if local_client is None:
|
||||
raise ValueError('Invalid client_id')
|
||||
|
||||
refresh_token = RefreshToken(user, client_id)
|
||||
user.refresh_tokens[refresh_token.token] = refresh_token
|
||||
await self.async_save()
|
||||
@@ -463,10 +487,10 @@ class AuthStore:
|
||||
|
||||
async def async_get_refresh_token(self, token):
|
||||
"""Get refresh token by token."""
|
||||
if self.users is None:
|
||||
if self._users is None:
|
||||
await self.async_load()
|
||||
|
||||
for user in self.users.values():
|
||||
for user in self._users.values():
|
||||
refresh_token = user.refresh_tokens.get(token)
|
||||
if refresh_token is not None:
|
||||
return refresh_token
|
||||
@@ -475,7 +499,7 @@ class AuthStore:
|
||||
|
||||
async def async_create_client(self, name, redirect_uris, no_secret):
|
||||
"""Create a new client."""
|
||||
if self.clients is None:
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
kwargs = {
|
||||
@@ -487,16 +511,23 @@ class AuthStore:
|
||||
kwargs['secret'] = None
|
||||
|
||||
client = Client(**kwargs)
|
||||
self.clients[client.id] = client
|
||||
self._clients[client.id] = client
|
||||
await self.async_save()
|
||||
return client
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
if self.clients is None:
|
||||
async def async_get_clients(self):
|
||||
"""Return all clients."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
return self.clients.get(client_id)
|
||||
return list(self._clients.values())
|
||||
|
||||
async def async_get_client(self, client_id):
|
||||
"""Get a client."""
|
||||
if self._clients is None:
|
||||
await self.async_load()
|
||||
|
||||
return self._clients.get(client_id)
|
||||
|
||||
async def async_load(self):
|
||||
"""Load the users."""
|
||||
@@ -504,12 +535,12 @@ class AuthStore:
|
||||
|
||||
# Make sure that we're not overriding data if 2 loads happened at the
|
||||
# same time
|
||||
if self.users is not None:
|
||||
if self._users is not None:
|
||||
return
|
||||
|
||||
if data is None:
|
||||
self.users = {}
|
||||
self.clients = {}
|
||||
self._users = {}
|
||||
self._clients = {}
|
||||
return
|
||||
|
||||
users = {
|
||||
@@ -553,8 +584,8 @@ class AuthStore:
|
||||
cl_dict['id']: Client(**cl_dict) for cl_dict in data['clients']
|
||||
}
|
||||
|
||||
self.users = users
|
||||
self.clients = clients
|
||||
self._users = users
|
||||
self._clients = clients
|
||||
|
||||
async def async_save(self):
|
||||
"""Save users."""
|
||||
@@ -565,7 +596,7 @@ class AuthStore:
|
||||
'is_active': user.is_active,
|
||||
'name': user.name,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
]
|
||||
|
||||
credentials = [
|
||||
@@ -576,7 +607,7 @@ class AuthStore:
|
||||
'auth_provider_id': credential.auth_provider_id,
|
||||
'data': credential.data,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for credential in user.credentials
|
||||
]
|
||||
|
||||
@@ -590,7 +621,7 @@ class AuthStore:
|
||||
refresh_token.access_token_expiration.total_seconds(),
|
||||
'token': refresh_token.token,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
]
|
||||
|
||||
@@ -601,7 +632,7 @@ class AuthStore:
|
||||
'created_at': access_token.created_at.isoformat(),
|
||||
'token': access_token.token,
|
||||
}
|
||||
for user in self.users.values()
|
||||
for user in self._users.values()
|
||||
for refresh_token in user.refresh_tokens.values()
|
||||
for access_token in refresh_token.access_tokens
|
||||
]
|
||||
@@ -613,7 +644,7 @@ class AuthStore:
|
||||
'secret': client.secret,
|
||||
'redirect_uris': client.redirect_uris,
|
||||
}
|
||||
for client in self.clients.values()
|
||||
for client in self._clients.values()
|
||||
]
|
||||
|
||||
data = {
|
||||
|
||||
Reference in New Issue
Block a user