diff --git a/hassio/addons/git.py b/hassio/addons/git.py index 4712d9b6d..a00521946 100644 --- a/hassio/addons/git.py +++ b/hassio/addons/git.py @@ -8,8 +8,9 @@ import shutil import git from .utils import get_hash_from_repository -from ..const import URL_HASSIO_ADDONS +from ..const import URL_HASSIO_ADDONS, ATTR_URL, ATTR_BRANCH from ..coresys import CoreSysAttributes +from ..validate import RE_REPOSITORY _LOGGER = logging.getLogger(__name__) @@ -22,9 +23,20 @@ class GitRepo(CoreSysAttributes): self.coresys = coresys self.repo = None self.path = path - self.url = url self.lock = asyncio.Lock(loop=coresys.loop) + self._data = RE_REPOSITORY.match(url).groupdict() + + @property + def url(self): + """Return repository URL.""" + return self._data[ATTR_URL] + + @property + def branch(self): + """Return repository branch.""" + return self._data[ATTR_BRANCH] + async def load(self): """Init git addon repo.""" if not self.path.is_dir(): @@ -46,12 +58,20 @@ class GitRepo(CoreSysAttributes): async def clone(self): """Clone git addon repo.""" async with self.lock: + git_args = { + attribute: value + for attribute, value in ( + ('recursive', True), + ('branch', self.branch) + ) if value is not None + } + try: _LOGGER.info("Clone addon %s repository", self.url) - self.repo = await self._loop.run_in_executor( - None, ft.partial( - git.Repo.clone_from, self.url, str(self.path), - recursive=True)) + self.repo = await self._loop.run_in_executor(None, ft.partial( + git.Repo.clone_from, self.url, str(self.path), + **git_args + )) except (git.InvalidGitRepositoryError, git.NoSuchPathError, git.GitCommandError) as err: diff --git a/hassio/const.py b/hassio/const.py index 2ddb7d737..0e71662ef 100644 --- a/hassio/const.py +++ b/hassio/const.py @@ -158,6 +158,7 @@ ATTR_SERVICES = 'services' ATTR_DISCOVERY = 'discovery' ATTR_PROTECTED = 'protected' ATTR_CRYPTO = 'crypto' +ATTR_BRANCH = 'branch' SERVICE_MQTT = 'mqtt' diff --git a/hassio/validate.py b/hassio/validate.py index ae52aca85..fd4a909be 100644 --- a/hassio/validate.py +++ b/hassio/validate.py @@ -1,5 +1,6 @@ """Validate functions.""" import uuid +import re import voluptuous as vol import pytz @@ -11,13 +12,29 @@ from .const import ( ATTR_SSL, ATTR_PORT, ATTR_WATCHDOG, ATTR_WAIT_BOOT, ATTR_UUID) +RE_REPOSITORY = re.compile(r"^(?P[^#]+)(?:#(?P[\w\-]+))?$") + NETWORK_PORT = vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)) ALSA_CHANNEL = vol.Match(r"\d+,\d+") WAIT_BOOT = vol.All(vol.Coerce(int), vol.Range(min=1, max=60)) DOCKER_IMAGE = vol.Match(r"^[\w{}]+/[\-\w{}]+$") + +def validate_repository(repository): + """Validate a valide repository.""" + data = RE_REPOSITORY.match(repository) + if not data: + raise vol.Invalid("No valid repository format!") + + # Validate URL + # pylint: disable=no-value-for-parameter + vol.Url()(data.group('url')) + + return repository + + # pylint: disable=no-value-for-parameter -REPOSITORIES = vol.All([vol.Url()], vol.Unique()) +REPOSITORIES = vol.All([validate_repository], vol.Unique()) def validate_timezone(timezone):