Update aiohttp to 2.3.1 (#10139)

* Update aiohttp to 2.3.1

* set timeout 10sec

* fix freeze with new middleware handling

* Convert middleware auth

* Convert mittleware ipban

* convert middleware static

* fix lint

* Update ban.py

* Update auth.py

* fix lint

* Fix tests
This commit is contained in:
Pascal Vizeli 2017-11-06 03:42:31 +01:00 committed by Paulus Schoutsen
parent 39de557c4c
commit a9a3e24bde
8 changed files with 67 additions and 92 deletions

View File

@ -182,8 +182,6 @@ class HomeAssistantWSGI(object):
use_x_forwarded_for, trusted_networks, use_x_forwarded_for, trusted_networks,
login_threshold, is_ban_enabled): login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server.""" """Initialize the WSGI Home Assistant server."""
import aiohttp_cors
middlewares = [auth_middleware, staticresource_middleware] middlewares = [auth_middleware, staticresource_middleware]
if is_ban_enabled: if is_ban_enabled:
@ -206,6 +204,8 @@ class HomeAssistantWSGI(object):
self.server = None self.server = None
if cors_origins: if cors_origins:
import aiohttp_cors
self.cors = aiohttp_cors.setup(self.app, defaults={ self.cors = aiohttp_cors.setup(self.app, defaults={
host: aiohttp_cors.ResourceOptions( host: aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS, allow_headers=ALLOWED_CORS_HEADERS,
@ -335,7 +335,9 @@ class HomeAssistantWSGI(object):
_LOGGER.error("Failed to create HTTP server at port %d: %s", _LOGGER.error("Failed to create HTTP server at port %d: %s",
self.server_port, error) self.server_port, error)
self.app._frozen = False # pylint: disable=protected-access # pylint: disable=protected-access
self.app._middlewares = tuple(self.app._prepare_middleware())
self.app._frozen = False
@asyncio.coroutine @asyncio.coroutine
def stop(self): def stop(self):
@ -345,7 +347,7 @@ class HomeAssistantWSGI(object):
yield from self.server.wait_closed() yield from self.server.wait_closed()
yield from self.app.shutdown() yield from self.app.shutdown()
if self._handler: if self._handler:
yield from self._handler.finish_connections(60.0) yield from self._handler.shutdown(10)
yield from self.app.cleanup() yield from self.app.cleanup()

View File

@ -5,6 +5,7 @@ import hmac
import logging import logging
from aiohttp import hdrs from aiohttp import hdrs
from aiohttp.web import middleware
from homeassistant.const import HTTP_HEADER_HA_AUTH from homeassistant.const import HTTP_HEADER_HA_AUTH
from .util import get_real_ip from .util import get_real_ip
@ -15,23 +16,16 @@ DATA_API_PASSWORD = 'api_password'
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@middleware
@asyncio.coroutine @asyncio.coroutine
def auth_middleware(app, handler): def auth_middleware(request, handler):
"""Authenticate as middleware.""" """Authenticate as middleware."""
# If no password set, just always set authenticated=True # If no password set, just always set authenticated=True
if app['hass'].http.api_password is None: if request.app['hass'].http.api_password is None:
@asyncio.coroutine
def no_auth_middleware_handler(request):
"""Auth middleware to approve all requests."""
request[KEY_AUTHENTICATED] = True request[KEY_AUTHENTICATED] = True
return handler(request) return handler(request)
return no_auth_middleware_handler # Check authentication
@asyncio.coroutine
def auth_middleware_handler(request):
"""Auth middleware to check authentication."""
# Auth code verbose on purpose
authenticated = False authenticated = False
if (HTTP_HEADER_HA_AUTH in request.headers and if (HTTP_HEADER_HA_AUTH in request.headers and
@ -52,11 +46,8 @@ def auth_middleware(app, handler):
authenticated = True authenticated = True
request[KEY_AUTHENTICATED] = authenticated request[KEY_AUTHENTICATED] = authenticated
return handler(request) return handler(request)
return auth_middleware_handler
def is_trusted_ip(request): def is_trusted_ip(request):
"""Test if request is from a trusted ip.""" """Test if request is from a trusted ip."""

View File

@ -6,6 +6,7 @@ from ipaddress import ip_address
import logging import logging
import os import os
from aiohttp.web import middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol import voluptuous as vol
@ -32,20 +33,19 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
}) })
@middleware
@asyncio.coroutine @asyncio.coroutine
def ban_middleware(app, handler): def ban_middleware(request, handler):
"""IP Ban middleware.""" """IP Ban middleware."""
if not app[KEY_BANS_ENABLED]: if not request.app[KEY_BANS_ENABLED]:
return handler return (yield from handler(request))
if KEY_BANNED_IPS not in app: if KEY_BANNED_IPS not in request.app:
hass = app['hass'] hass = request.app['hass']
app[KEY_BANNED_IPS] = yield from hass.async_add_job( request.app[KEY_BANNED_IPS] = yield from hass.async_add_job(
load_ip_bans_config, hass.config.path(IP_BANS_FILE)) load_ip_bans_config, hass.config.path(IP_BANS_FILE))
@asyncio.coroutine # Verify if IP is not banned
def ban_middleware_handler(request):
"""Verify if IP is not banned."""
ip_address_ = get_real_ip(request) ip_address_ = get_real_ip(request)
is_banned = any(ip_ban.ip_address == ip_address_ is_banned = any(ip_ban.ip_address == ip_address_
@ -60,8 +60,6 @@ def ban_middleware(app, handler):
yield from process_wrong_login(request) yield from process_wrong_login(request)
raise raise
return ban_middleware_handler
@asyncio.coroutine @asyncio.coroutine
def process_wrong_login(request): def process_wrong_login(request):

View File

@ -3,7 +3,7 @@ import asyncio
import re import re
from aiohttp import hdrs from aiohttp import hdrs
from aiohttp.web import FileResponse from aiohttp.web import FileResponse, middleware
from aiohttp.web_exceptions import HTTPNotFound from aiohttp.web_exceptions import HTTPNotFound
from aiohttp.web_urldispatcher import StaticResource from aiohttp.web_urldispatcher import StaticResource
from yarl import unquote from yarl import unquote
@ -61,12 +61,10 @@ class CachingFileResponse(FileResponse):
self._sendfile = sendfile self._sendfile = sendfile
@middleware
@asyncio.coroutine @asyncio.coroutine
def staticresource_middleware(app, handler): def staticresource_middleware(request, handler):
"""Middleware to strip out fingerprint from fingerprinted assets.""" """Middleware to strip out fingerprint from fingerprinted assets."""
@asyncio.coroutine
def static_middleware_handler(request):
"""Strip out fingerprints from resource names."""
if not request.path.startswith('/static/'): if not request.path.startswith('/static/'):
return handler(request) return handler(request)
@ -77,5 +75,3 @@ def staticresource_middleware(app, handler):
'{}.{}'.format(*fingerprinted.groups()) '{}.{}'.format(*fingerprinted.groups())
return handler(request) return handler(request)
return static_middleware_handler

View File

@ -5,7 +5,7 @@ pip>=8.0.3
jinja2>=2.9.6 jinja2>=2.9.6
voluptuous==0.10.5 voluptuous==0.10.5
typing>=3,<4 typing>=3,<4
aiohttp==2.2.5 aiohttp==2.3.1
async_timeout==2.0.0 async_timeout==2.0.0
chardet==3.0.4 chardet==3.0.4
astral==1.4 astral==1.4

View File

@ -6,7 +6,7 @@ pip>=8.0.3
jinja2>=2.9.6 jinja2>=2.9.6
voluptuous==0.10.5 voluptuous==0.10.5
typing>=3,<4 typing>=3,<4
aiohttp==2.2.5 aiohttp==2.3.1
async_timeout==2.0.0 async_timeout==2.0.0
chardet==3.0.4 chardet==3.0.4
astral==1.4 astral==1.4

View File

@ -53,7 +53,7 @@ REQUIRES = [
'jinja2>=2.9.6', 'jinja2>=2.9.6',
'voluptuous==0.10.5', 'voluptuous==0.10.5',
'typing>=3,<4', 'typing>=3,<4',
'aiohttp==2.2.5', 'aiohttp==2.3.1',
'async_timeout==2.0.0', 'async_timeout==2.0.0',
'chardet==3.0.4', 'chardet==3.0.4',
'astral==1.4', 'astral==1.4',

View File

@ -143,22 +143,10 @@ def test_registering_view_while_running(hass, test_client):
} }
) )
yield from setup.async_setup_component(hass, 'api')
yield from hass.async_start() yield from hass.async_start()
# This raises a RuntimeError if app is frozen
yield from hass.async_block_till_done()
hass.http.register_view(TestView) hass.http.register_view(TestView)
client = yield from test_client(hass.http.app)
resp = yield from client.get('/hello')
assert resp.status == 200
text = yield from resp.text()
assert text == 'hello'
@asyncio.coroutine @asyncio.coroutine
def test_api_base_url_with_domain(hass): def test_api_base_url_with_domain(hass):