From c556b619b79341be699b24a194f0416cbc529c7a Mon Sep 17 00:00:00 2001 From: Lev Aronsky Date: Tue, 23 May 2017 19:55:01 +0300 Subject: [PATCH] Asuswrt continuous ssh (#7728) * Make ssh and telnet connections continuous in asuswrt * Refactored SSH and Telnet connections into respective classes. * Fixed several copy-paste typos and errors. * More typos fixed. * Small changes to arguments, to pass automated tests. * Removed unsupported named arguments. * Fixed a couple of mistakes in Telnet, and other lint errors. * Added Telnet tests, and added lint exceptions. * Removed comments from tests, as they irritated the hound. --- .../components/device_tracker/asuswrt.py | 301 ++++++++++++------ .../components/device_tracker/test_asuswrt.py | 82 ++++- 2 files changed, 274 insertions(+), 109 deletions(-) diff --git a/homeassistant/components/device_tracker/asuswrt.py b/homeassistant/components/device_tracker/asuswrt.py index a0405b0b690..cc50ab44e54 100644 --- a/homeassistant/components/device_tracker/asuswrt.py +++ b/homeassistant/components/device_tracker/asuswrt.py @@ -118,25 +118,29 @@ class AsusWrtDeviceScanner(DeviceScanner): self.protocol = config[CONF_PROTOCOL] self.mode = config[CONF_MODE] self.port = config[CONF_PORT] - self.ssh_args = {} if self.protocol == 'ssh': - - self.ssh_args['port'] = self.port - if self.ssh_key: - self.ssh_args['ssh_key'] = self.ssh_key - elif self.password: - self.ssh_args['password'] = self.password - else: + if not (self.ssh_key or self.password): _LOGGER.error("No password or private key specified") self.success_init = False return + + self.connection = SshConnection(self.host, self.port, + self.username, + self.password, + self.ssh_key, + self.mode == "ap") else: if not self.password: _LOGGER.error("No password specified") self.success_init = False return + self.connection = TelnetConnection(self.host, self.port, + self.username, + self.password, + self.mode == "ap") + self.lock = threading.Lock() self.last_results = {} @@ -182,105 +186,9 @@ class AsusWrtDeviceScanner(DeviceScanner): self.last_results = active_clients return True - def ssh_connection(self): - """Retrieve data from ASUSWRT via the ssh protocol.""" - from pexpect import pxssh, exceptions - - ssh = pxssh.pxssh() - try: - ssh.login(self.host, self.username, **self.ssh_args) - except exceptions.EOF as err: - _LOGGER.error("Connection refused. SSH enabled?") - return None - except pxssh.ExceptionPxssh as err: - _LOGGER.error("Unable to connect via SSH: %s", str(err)) - return None - - try: - ssh.sendline(_IP_NEIGH_CMD) - ssh.prompt() - neighbors = ssh.before.split(b'\n')[1:-1] - if self.mode == 'ap': - ssh.sendline(_ARP_CMD) - ssh.prompt() - arp_result = ssh.before.split(b'\n')[1:-1] - ssh.sendline(_WL_CMD) - ssh.prompt() - leases_result = ssh.before.split(b'\n')[1:-1] - ssh.sendline(_NVRAM_CMD) - ssh.prompt() - nvram_result = ssh.before.split(b'\n')[1].split(b'<')[1:] - else: - arp_result = [''] - nvram_result = [''] - ssh.sendline(_LEASES_CMD) - ssh.prompt() - leases_result = ssh.before.split(b'\n')[1:-1] - ssh.logout() - return AsusWrtResult(neighbors, leases_result, arp_result, - nvram_result) - except pxssh.ExceptionPxssh as exc: - _LOGGER.error("Unexpected response from router: %s", exc) - return None - - def telnet_connection(self): - """Retrieve data from ASUSWRT via the telnet protocol.""" - try: - telnet = telnetlib.Telnet(self.host) - telnet.read_until(b'login: ') - telnet.write((self.username + '\n').encode('ascii')) - telnet.read_until(b'Password: ') - telnet.write((self.password + '\n').encode('ascii')) - prompt_string = telnet.read_until(b'#').split(b'\n')[-1] - telnet.write('{}\n'.format(_IP_NEIGH_CMD).encode('ascii')) - neighbors = telnet.read_until(prompt_string).split(b'\n')[1:-1] - if self.mode == 'ap': - telnet.write('{}\n'.format(_ARP_CMD).encode('ascii')) - arp_result = (telnet.read_until(prompt_string). - split(b'\n')[1:-1]) - telnet.write('{}\n'.format(_WL_CMD).encode('ascii')) - leases_result = (telnet.read_until(prompt_string). - split(b'\n')[1:-1]) - telnet.write('{}\n'.format(_NVRAM_CMD).encode('ascii')) - nvram_result = (telnet.read_until(prompt_string). - split(b'\n')[1].split(b'<')[1:]) - else: - arp_result = [''] - nvram_result = [''] - telnet.write('{}\n'.format(_LEASES_CMD).encode('ascii')) - leases_result = (telnet.read_until(prompt_string). - split(b'\n')[1:-1]) - telnet.write('exit\n'.encode('ascii')) - return AsusWrtResult(neighbors, leases_result, arp_result, - nvram_result) - except EOFError: - _LOGGER.error("Unexpected response from router") - return None - except ConnectionRefusedError: - _LOGGER.error("Connection refused by router. Telnet enabled?") - return None - except socket.gaierror as exc: - _LOGGER.error("Socket exception: %s", exc) - return None - except OSError as exc: - _LOGGER.error("OSError: %s", exc) - return None - def get_asuswrt_data(self): """Retrieve data from ASUSWRT and return parsed result.""" - if self.protocol == 'ssh': - result = self.ssh_connection() - elif self.protocol == 'telnet': - result = self.telnet_connection() - else: - # autodetect protocol - result = self.ssh_connection() - if result: - self.protocol = 'ssh' - else: - result = self.telnet_connection() - if result: - self.protocol = 'telnet' + result = self.connection.get_result() if not result: return {} @@ -363,3 +271,186 @@ class AsusWrtDeviceScanner(DeviceScanner): if match.group('ip') in devices: devices[match.group('ip')]['status'] = match.group('status') return devices + + +class _Connection: + def __init__(self): + self._connected = False + + @property + def connected(self): + """Return connection state.""" + return self._connected + + def connect(self): + """Mark currenct connection state as connected.""" + self._connected = True + + def disconnect(self): + """Mark current connection state as disconnected.""" + self._connected = False + + +class SshConnection(_Connection): + """Maintains an SSH connection to an ASUS-WRT router.""" + + def __init__(self, host, port, username, password, ssh_key, ap): + """Initialize the SSH connection properties.""" + from pexpect import pxssh + + super(SshConnection, self).__init__() + + self._ssh = pxssh.pxssh() + self._host = host + self._port = port + self._username = username + self._password = password + self._ssh_key = ssh_key + self._ap = ap + + def get_result(self): + """Retrieve a single AsusWrtResult through an SSH connection. + + Connect to the SSH server if not currently connected, otherwise + use the existing connection. + """ + from pexpect import pxssh, exceptions + + try: + if not self.connected: + self.connect() + self._ssh.sendline(_IP_NEIGH_CMD) + self._ssh.prompt() + neighbors = self._ssh.before.split(b'\n')[1:-1] + if self._ap: + self._ssh.sendline(_ARP_CMD) + self._ssh.prompt() + arp_result = self._ssh.before.split(b'\n')[1:-1] + self._ssh.sendline(_WL_CMD) + self._ssh.prompt() + leases_result = self._ssh.before.split(b'\n')[1:-1] + self._ssh.sendline(_NVRAM_CMD) + self._ssh.prompt() + nvram_result = self._ssh.before.split(b'\n')[1].split(b'<')[1:] + else: + arp_result = [''] + nvram_result = [''] + self._ssh.sendline(_LEASES_CMD) + self._ssh.prompt() + leases_result = self._ssh.before.split(b'\n')[1:-1] + return AsusWrtResult(neighbors, leases_result, arp_result, + nvram_result) + except exceptions.EOF as err: + _LOGGER.error("Connection refused. SSH enabled?") + self.disconnect() + return None + except pxssh.ExceptionPxssh as err: + _LOGGER.error("Unexpected SSH error: %s", str(err)) + self.disconnect() + return None + + def connect(self): + """Connect to the ASUS-WRT SSH server.""" + if self._ssh_key: + self._ssh.login(self._host, self._username, + ssh_key=self._ssh_key, port=self._port) + else: + self._ssh.login(self._host, self._username, + password=self._password, port=self._port) + + super(SshConnection, self).connect() + + def disconnect(self): \ + # pylint: disable=broad-except + """Disconnect the current SSH connection.""" + try: + self._ssh.logout() + except Exception: + pass + + super(SshConnection, self).disconnect() + + +class TelnetConnection(_Connection): + """Maintains a Telnet connection to an ASUS-WRT router.""" + + def __init__(self, host, port, username, password, ap): + """Initialize the Telnet connection properties.""" + super(TelnetConnection, self).__init__() + + self._telnet = None + self._host = host + self._port = port + self._username = username + self._password = password + self._ap = ap + self._prompt_string = None + + def get_result(self): + """Retrieve a single AsusWrtResult through a Telnet connection. + + Connect to the Telnet server if not currently connected, otherwise + use the existing connection. + """ + try: + if not self.connected: + self.connect() + + self._telnet.write('{}\n'.format(_IP_NEIGH_CMD).encode('ascii')) + neighbors = (self._telnet.read_until(self._prompt_string). + split(b'\n')[1:-1]) + if self._ap: + self._telnet.write('{}\n'.format(_ARP_CMD).encode('ascii')) + arp_result = (self._telnet.read_until(self._prompt_string). + split(b'\n')[1:-1]) + self._telnet.write('{}\n'.format(_WL_CMD).encode('ascii')) + leases_result = (self._telnet.read_until(self._prompt_string). + split(b'\n')[1:-1]) + self._telnet.write('{}\n'.format(_NVRAM_CMD).encode('ascii')) + nvram_result = (self._telnet.read_until(self._prompt_string). + split(b'\n')[1].split(b'<')[1:]) + else: + arp_result = [''] + nvram_result = [''] + self._telnet.write('{}\n'.format(_LEASES_CMD).encode('ascii')) + leases_result = (self._telnet.read_until(self._prompt_string). + split(b'\n')[1:-1]) + return AsusWrtResult(neighbors, leases_result, arp_result, + nvram_result) + except EOFError: + _LOGGER.error("Unexpected response from router") + self.disconnect() + return None + except ConnectionRefusedError: + _LOGGER.error("Connection refused by router. Telnet enabled?") + self.disconnect() + return None + except socket.gaierror as exc: + _LOGGER.error("Socket exception: %s", exc) + self.disconnect() + return None + except OSError as exc: + _LOGGER.error("OSError: %s", exc) + self.disconnect() + return None + + def connect(self): + """Connect to the ASUS-WRT Telnet server.""" + self._telnet = telnetlib.Telnet(self._host) + self._telnet.read_until(b'login: ') + self._telnet.write((self._username + '\n').encode('ascii')) + self._telnet.read_until(b'Password: ') + self._telnet.write((self._password + '\n').encode('ascii')) + self._prompt_string = self._telnet.read_until(b'#').split(b'\n')[-1] + + super(TelnetConnection, self).connect() + + def disconnect(self): \ + # pylint: disable=broad-except + """Disconnect the current Telnet connection.""" + try: + self._telnet.write('exit\n'.encode('ascii')) + except Exception: + pass + + super(TelnetConnection, self).disconnect() diff --git a/tests/components/device_tracker/test_asuswrt.py b/tests/components/device_tracker/test_asuswrt.py index 81d3c7a1900..0de5ac67a30 100644 --- a/tests/components/device_tracker/test_asuswrt.py +++ b/tests/components/device_tracker/test_asuswrt.py @@ -135,11 +135,12 @@ class TestComponentsDeviceTrackerASUSWRT(unittest.TestCase): update_mock.start() self.addCleanup(update_mock.stop) asuswrt = device_tracker.asuswrt.AsusWrtDeviceScanner(conf_dict) - asuswrt.ssh_connection() + asuswrt.connection.get_result() self.assertEqual(ssh.login.call_count, 1) self.assertEqual( ssh.login.call_args, - mock.call('fake_host', 'fake_user', port=22, ssh_key=FAKEFILE) + mock.call('fake_host', 'fake_user', + ssh_key=FAKEFILE, port=22) ) def test_ssh_login_with_password(self): @@ -160,11 +161,12 @@ class TestComponentsDeviceTrackerASUSWRT(unittest.TestCase): update_mock.start() self.addCleanup(update_mock.stop) asuswrt = device_tracker.asuswrt.AsusWrtDeviceScanner(conf_dict) - asuswrt.ssh_connection() + asuswrt.connection.get_result() self.assertEqual(ssh.login.call_count, 1) self.assertEqual( ssh.login.call_args, - mock.call('fake_host', 'fake_user', password='fake_pass', port=22) + mock.call('fake_host', 'fake_user', + password='fake_pass', port=22) ) def test_ssh_login_without_password_or_pubkey(self): \ @@ -194,3 +196,75 @@ class TestComponentsDeviceTrackerASUSWRT(unittest.TestCase): assert setup_component(self.hass, DOMAIN, {DOMAIN: conf_dict}) ssh.login.assert_not_called() + + def test_telnet_login_with_password(self): + """Test that login is done with password when configured to.""" + telnet = mock.MagicMock() + telnet_mock = mock.patch('telnetlib.Telnet', return_value=telnet) + telnet_mock.start() + self.addCleanup(telnet_mock.stop) + conf_dict = PLATFORM_SCHEMA({ + CONF_PLATFORM: 'asuswrt', + CONF_PROTOCOL: 'telnet', + CONF_HOST: 'fake_host', + CONF_USERNAME: 'fake_user', + CONF_PASSWORD: 'fake_pass' + }) + update_mock = mock.patch( + 'homeassistant.components.device_tracker.asuswrt.' + 'AsusWrtDeviceScanner.get_asuswrt_data') + update_mock.start() + self.addCleanup(update_mock.stop) + asuswrt = device_tracker.asuswrt.AsusWrtDeviceScanner(conf_dict) + asuswrt.connection.get_result() + self.assertEqual(telnet.read_until.call_count, 5) + self.assertEqual(telnet.write.call_count, 4) + self.assertEqual( + telnet.read_until.call_args_list[0], + mock.call(b'login: ') + ) + self.assertEqual( + telnet.write.call_args_list[0], + mock.call(b'fake_user\n') + ) + self.assertEqual( + telnet.read_until.call_args_list[1], + mock.call(b'Password: ') + ) + self.assertEqual( + telnet.write.call_args_list[1], + mock.call(b'fake_pass\n') + ) + self.assertEqual( + telnet.read_until.call_args_list[2], + mock.call(b'#') + ) + + def test_telnet_login_without_password(self): \ + # pylint: disable=invalid-name + """Test that login is not called without password or pub_key.""" + telnet = mock.MagicMock() + telnet_mock = mock.patch('telnetlib.Telnet', return_value=telnet) + telnet_mock.start() + self.addCleanup(telnet_mock.stop) + + conf_dict = { + CONF_PLATFORM: 'asuswrt', + CONF_PROTOCOL: 'telnet', + CONF_HOST: 'fake_host', + CONF_USERNAME: 'fake_user', + } + + with self.assertRaises(vol.Invalid): + conf_dict = PLATFORM_SCHEMA(conf_dict) + + update_mock = mock.patch( + 'homeassistant.components.device_tracker.asuswrt.' + 'AsusWrtDeviceScanner.get_asuswrt_data') + update_mock.start() + self.addCleanup(update_mock.stop) + + with assert_setup_component(0): + assert setup_component(self.hass, DOMAIN, + {DOMAIN: conf_dict}) + telnet.login.assert_not_called()