Handle Plex certificate updates (#33230)

* Handle Plex certificate updates

* Use exception in place

* Add test
This commit is contained in:
jjlawren 2020-03-28 23:02:29 -05:00 committed by GitHub
parent 3c2df7f8f2
commit 312af53935
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 5 deletions

View File

@ -46,6 +46,7 @@ from .const import (
SERVERS, SERVERS,
WEBSOCKETS, WEBSOCKETS,
) )
from .errors import ShouldUpdateConfigEntry
from .server import PlexServer from .server import PlexServer
MEDIA_PLAYER_SCHEMA = vol.All( MEDIA_PLAYER_SCHEMA = vol.All(
@ -129,9 +130,20 @@ async def async_setup_entry(hass, entry):
) )
hass.config_entries.async_update_entry(entry, options=options) hass.config_entries.async_update_entry(entry, options=options)
plex_server = PlexServer(hass, server_config, entry.options) plex_server = PlexServer(
hass, server_config, entry.data[CONF_SERVER_IDENTIFIER], entry.options
)
try: try:
await hass.async_add_executor_job(plex_server.connect) await hass.async_add_executor_job(plex_server.connect)
except ShouldUpdateConfigEntry:
new_server_data = {
**entry.data[PLEX_SERVER_CONFIG],
CONF_URL: plex_server.url_in_use,
CONF_SERVER: plex_server.friendly_name,
}
hass.config_entries.async_update_entry(
entry, data={**entry.data, PLEX_SERVER_CONFIG: new_server_data}
)
except requests.exceptions.ConnectionError as error: except requests.exceptions.ConnectionError as error:
_LOGGER.error( _LOGGER.error(
"Plex server (%s) could not be reached: [%s]", "Plex server (%s) could not be reached: [%s]",

View File

@ -12,3 +12,7 @@ class NoServersFound(PlexException):
class ServerNotSpecified(PlexException): class ServerNotSpecified(PlexException):
"""Multiple servers linked to account without choice provided.""" """Multiple servers linked to account without choice provided."""
class ShouldUpdateConfigEntry(PlexException):
"""Config entry data is out of date and should be updated."""

View File

@ -1,5 +1,7 @@
"""Shared class to maintain Plex server instances.""" """Shared class to maintain Plex server instances."""
import logging import logging
import ssl
from urllib.parse import urlparse
import plexapi.myplex import plexapi.myplex
import plexapi.playqueue import plexapi.playqueue
@ -26,7 +28,7 @@ from .const import (
X_PLEX_PRODUCT, X_PLEX_PRODUCT,
X_PLEX_VERSION, X_PLEX_VERSION,
) )
from .errors import NoServersFound, ServerNotSpecified from .errors import NoServersFound, ServerNotSpecified, ShouldUpdateConfigEntry
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -40,7 +42,7 @@ plexapi.X_PLEX_VERSION = X_PLEX_VERSION
class PlexServer: class PlexServer:
"""Manages a single Plex server connection.""" """Manages a single Plex server connection."""
def __init__(self, hass, server_config, options=None): def __init__(self, hass, server_config, known_server_id=None, options=None):
"""Initialize a Plex server instance.""" """Initialize a Plex server instance."""
self._hass = hass self._hass = hass
self._plex_server = None self._plex_server = None
@ -50,6 +52,7 @@ class PlexServer:
self._token = server_config.get(CONF_TOKEN) self._token = server_config.get(CONF_TOKEN)
self._server_name = server_config.get(CONF_SERVER) self._server_name = server_config.get(CONF_SERVER)
self._verify_ssl = server_config.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL) self._verify_ssl = server_config.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL)
self._server_id = known_server_id
self.options = options self.options = options
self.server_choice = None self.server_choice = None
self._accounts = [] self._accounts = []
@ -64,6 +67,7 @@ class PlexServer:
def connect(self): def connect(self):
"""Connect to a Plex server directly, obtaining direct URL if necessary.""" """Connect to a Plex server directly, obtaining direct URL if necessary."""
config_entry_update_needed = False
def _connect_with_token(): def _connect_with_token():
account = plexapi.myplex.MyPlexAccount(token=self._token) account = plexapi.myplex.MyPlexAccount(token=self._token)
@ -92,8 +96,33 @@ class PlexServer:
self._url, self._token, session self._url, self._token, session
) )
def _update_plexdirect_hostname():
account = plexapi.myplex.MyPlexAccount(token=self._token)
matching_server = [
x.name
for x in account.resources()
if x.clientIdentifier == self._server_id
][0]
self._plex_server = account.resource(matching_server).connect(timeout=10)
if self._url: if self._url:
_connect_with_url() try:
_connect_with_url()
except requests.exceptions.SSLError as error:
while error and not isinstance(error, ssl.SSLCertVerificationError):
error = error.__context__ # pylint: disable=no-member
if isinstance(error, ssl.SSLCertVerificationError):
domain = urlparse(self._url).netloc.split(":")[0]
if domain.endswith("plex.direct") and error.args[0].startswith(
f"hostname '{domain}' doesn't match"
):
_LOGGER.warning(
"Plex SSL certificate's hostname changed, updating."
)
_update_plexdirect_hostname()
config_entry_update_needed = True
else:
raise
else: else:
_connect_with_token() _connect_with_token()
@ -113,6 +142,9 @@ class PlexServer:
self._version = self._plex_server.version self._version = self._plex_server.version
if config_entry_update_needed:
raise ShouldUpdateConfigEntry
def refresh_entity(self, machine_identifier, device, session): def refresh_entity(self, machine_identifier, device, session):
"""Forward refresh dispatch to media_player.""" """Forward refresh dispatch to media_player."""
unique_id = f"{self.machine_identifier}:{machine_identifier}" unique_id = f"{self.machine_identifier}:{machine_identifier}"

View File

@ -1,6 +1,7 @@
"""Tests for Plex setup.""" """Tests for Plex setup."""
import copy import copy
from datetime import timedelta from datetime import timedelta
import ssl
from asynctest import patch from asynctest import patch
import plexapi import plexapi
@ -19,6 +20,7 @@ from homeassistant.const import (
CONF_PORT, CONF_PORT,
CONF_SSL, CONF_SSL,
CONF_TOKEN, CONF_TOKEN,
CONF_URL,
CONF_VERIFY_SSL, CONF_VERIFY_SSL,
) )
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
@ -26,7 +28,7 @@ from homeassistant.setup import async_setup_component
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import DEFAULT_DATA, DEFAULT_OPTIONS, MOCK_SERVERS, MOCK_TOKEN from .const import DEFAULT_DATA, DEFAULT_OPTIONS, MOCK_SERVERS, MOCK_TOKEN
from .mock_classes import MockPlexServer from .mock_classes import MockPlexAccount, MockPlexServer
from tests.common import MockConfigEntry, async_fire_time_changed from tests.common import MockConfigEntry, async_fire_time_changed
@ -300,3 +302,45 @@ async def test_setup_with_photo_session(hass):
sensor = hass.states.get("sensor.plex_plex_server_1") sensor = hass.states.get("sensor.plex_plex_server_1")
assert sensor.state == str(len(mock_plex_server.accounts)) assert sensor.state == str(len(mock_plex_server.accounts))
async def test_setup_when_certificate_changed(hass):
"""Test setup component when the Plex certificate has changed."""
old_domain = "1-2-3-4.1234567890abcdef1234567890abcdef.plex.direct"
old_url = f"https://{old_domain}:32400"
OLD_HOSTNAME_DATA = copy.deepcopy(DEFAULT_DATA)
OLD_HOSTNAME_DATA[const.PLEX_SERVER_CONFIG][CONF_URL] = old_url
class WrongCertHostnameException(requests.exceptions.SSLError):
"""Mock the exception showing a mismatched hostname."""
def __init__(self):
self.__context__ = ssl.SSLCertVerificationError(
f"hostname '{old_domain}' doesn't match"
)
old_entry = MockConfigEntry(
domain=const.DOMAIN,
data=OLD_HOSTNAME_DATA,
options=DEFAULT_OPTIONS,
unique_id=DEFAULT_DATA["server_id"],
)
new_entry = MockConfigEntry(domain=const.DOMAIN, data=DEFAULT_DATA)
with patch(
"plexapi.server.PlexServer", side_effect=WrongCertHostnameException
), patch("plexapi.myplex.MyPlexAccount", return_value=MockPlexAccount()):
old_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(old_entry.entry_id)
await hass.async_block_till_done()
assert len(hass.config_entries.async_entries(const.DOMAIN)) == 1
assert old_entry.state == ENTRY_STATE_LOADED
assert (
old_entry.data[const.PLEX_SERVER_CONFIG][CONF_URL]
== new_entry.data[const.PLEX_SERVER_CONFIG][CONF_URL]
)