mirror of
https://github.com/home-assistant/supervisor.git
synced 2025-12-10 18:08:58 +00:00
Add registry manifest fetcher for size-based pull progress
Fetch image manifests directly from container registries before pulling to get accurate layer sizes upfront. This enables size-weighted progress tracking where each layer contributes proportionally to its byte size, rather than equal weight per layer.

Key changes:
- Add RegistryManifestFetcher that handles auth discovery via WWW-Authenticate headers, token fetching with optional credentials, and multi-arch manifest list resolution
- Update ImagePullProgress to accept manifest layer sizes via set_manifest() and calculate size-weighted progress
- Fall back to count-based progress when manifest fetch fails
- Pre-populate layer sizes from manifest when creating layer trackers

The manifest fetcher supports ghcr.io, Docker Hub, and private registries by using credentials from Docker config when available.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -213,9 +213,26 @@ class DockerInterface(JobGroup, ABC):
|
|||||||
raise ValueError("Cannot pull without an image!")
|
raise ValueError("Cannot pull without an image!")
|
||||||
|
|
||||||
image_arch = arch or self.sys_arch.supervisor
|
image_arch = arch or self.sys_arch.supervisor
|
||||||
|
platform = MAP_ARCH[image_arch]
|
||||||
pull_progress = ImagePullProgress()
|
pull_progress = ImagePullProgress()
|
||||||
current_job = self.sys_jobs.current
|
current_job = self.sys_jobs.current
|
||||||
|
|
||||||
|
# Try to fetch manifest for accurate size-based progress
|
||||||
|
# This is optional - if it fails, we fall back to count-based progress
|
||||||
|
try:
|
||||||
|
manifest = await self.sys_docker.manifest_fetcher.get_manifest(
|
||||||
|
image, str(version), platform=platform
|
||||||
|
)
|
||||||
|
if manifest:
|
||||||
|
pull_progress.set_manifest(manifest)
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Using manifest for progress: %d layers, %d bytes",
|
||||||
|
manifest.layer_count,
|
||||||
|
manifest.total_size,
|
||||||
|
)
|
||||||
|
except Exception as err: # noqa: BLE001
|
||||||
|
_LOGGER.debug("Could not fetch manifest for progress: %s", err)
|
||||||
|
|
||||||
async def process_pull_event(event: PullLogEntry) -> None:
|
async def process_pull_event(event: PullLogEntry) -> None:
|
||||||
"""Process pull event and update job progress."""
|
"""Process pull event and update job progress."""
|
||||||
if event.job_id != current_job.uuid:
|
if event.job_id != current_job.uuid:
|
||||||
@@ -244,7 +261,7 @@ class DockerInterface(JobGroup, ABC):
|
|||||||
current_job.uuid,
|
current_job.uuid,
|
||||||
image,
|
image,
|
||||||
str(version),
|
str(version),
|
||||||
platform=MAP_ARCH[image_arch],
|
platform=platform,
|
||||||
auth=credentials,
|
auth=credentials,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from ..exceptions import (
|
|||||||
from ..utils.common import FileConfiguration
|
from ..utils.common import FileConfiguration
|
||||||
from ..validate import SCHEMA_DOCKER_CONFIG
|
from ..validate import SCHEMA_DOCKER_CONFIG
|
||||||
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, LABEL_MANAGED
|
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, LABEL_MANAGED
|
||||||
|
from .manifest import RegistryManifestFetcher
|
||||||
from .monitor import DockerMonitor
|
from .monitor import DockerMonitor
|
||||||
from .network import DockerNetwork
|
from .network import DockerNetwork
|
||||||
from .utils import get_registry_from_image
|
from .utils import get_registry_from_image
|
||||||
@@ -258,6 +259,9 @@ class DockerAPI(CoreSysAttributes):
|
|||||||
self._info: DockerInfo | None = None
|
self._info: DockerInfo | None = None
|
||||||
self.config: DockerConfig = DockerConfig()
|
self.config: DockerConfig = DockerConfig()
|
||||||
self._monitor: DockerMonitor = DockerMonitor(coresys)
|
self._monitor: DockerMonitor = DockerMonitor(coresys)
|
||||||
|
self._manifest_fetcher: RegistryManifestFetcher = RegistryManifestFetcher(
|
||||||
|
coresys
|
||||||
|
)
|
||||||
|
|
||||||
async def post_init(self) -> Self:
|
async def post_init(self) -> Self:
|
||||||
"""Post init actions that must be done in event loop."""
|
"""Post init actions that must be done in event loop."""
|
||||||
@@ -323,6 +327,11 @@ class DockerAPI(CoreSysAttributes):
|
|||||||
"""Return docker events monitor."""
|
"""Return docker events monitor."""
|
||||||
return self._monitor
|
return self._monitor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def manifest_fetcher(self) -> RegistryManifestFetcher:
|
||||||
|
"""Return manifest fetcher for registry access."""
|
||||||
|
return self._manifest_fetcher
|
||||||
|
|
||||||
async def load(self) -> None:
|
async def load(self) -> None:
|
||||||
"""Start docker events monitor."""
|
"""Start docker events monitor."""
|
||||||
await self.monitor.load()
|
await self.monitor.load()
|
||||||
|
|||||||
354
supervisor/docker/manifest.py
Normal file
354
supervisor/docker/manifest.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
"""Docker registry manifest fetcher.
|
||||||
|
|
||||||
|
Fetches image manifests directly from container registries to get layer sizes
|
||||||
|
before pulling an image. This enables accurate size-based progress tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .const import DOCKER_HUB, IMAGE_WITH_HOST
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..coresys import CoreSys
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default registry for images without explicit host
|
||||||
|
DEFAULT_REGISTRY = "registry-1.docker.io"
|
||||||
|
|
||||||
|
# Media types for manifest requests
|
||||||
|
MANIFEST_MEDIA_TYPES = (
|
||||||
|
"application/vnd.docker.distribution.manifest.v2+json",
|
||||||
|
"application/vnd.oci.image.manifest.v1+json",
|
||||||
|
"application/vnd.docker.distribution.manifest.list.v2+json",
|
||||||
|
"application/vnd.oci.image.index.v1+json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ImageManifest:
    """Layer metadata extracted from a container image manifest."""

    # Digest string recorded for the image
    digest: str
    # Total size in bytes recorded for the image's layers
    total_size: int
    # Mapping of layer digest -> layer size in bytes
    layers: dict[str, int]

    @property
    def layer_count(self) -> int:
        """Return how many layers are recorded."""
        recorded_layers = self.layers
        return len(recorded_layers)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_image_reference(image: str, tag: str) -> tuple[str, str, str]:
    """Parse image reference into (registry, repository, tag).

    The registry is the host that serves the registry HTTP API. Images
    without an explicit host, and images that name Docker Hub by one of its
    user-facing hostnames, are resolved to DEFAULT_REGISTRY.

    Examples:
        ghcr.io/home-assistant/home-assistant:2025.1.0
            -> (ghcr.io, home-assistant/home-assistant, 2025.1.0)
        homeassistant/home-assistant:latest
            -> (registry-1.docker.io, homeassistant/home-assistant, latest)
        alpine:3.18
            -> (registry-1.docker.io, library/alpine, 3.18)
        docker.io/alpine:3.18
            -> (registry-1.docker.io, library/alpine, 3.18)

    Args:
        image: Image name, optionally prefixed with a registry host.
        tag: Image tag; returned unchanged.

    Returns:
        Tuple of (registry, repository, tag).

    """
    # Docker canonicalises these hostnames to Docker Hub, whose registry API
    # endpoint is DEFAULT_REGISTRY. Passing them through unchanged would send
    # manifest requests to the wrong host and skip the Docker Hub credential
    # mapping, which is keyed on DEFAULT_REGISTRY.
    docker_hub_aliases = ("docker.io", "index.docker.io")

    # Check if image has explicit registry host
    match = IMAGE_WITH_HOST.match(image)
    if match:
        registry = match.group(1)
        repository = image[len(registry) + 1 :]  # Remove "registry/" prefix
        if registry in docker_hub_aliases:
            registry = DEFAULT_REGISTRY
    else:
        registry = DEFAULT_REGISTRY
        repository = image

    # Docker Hub requires "library/" prefix for official images
    if registry == DEFAULT_REGISTRY and "/" not in repository:
        repository = f"library/{repository}"

    return registry, repository, tag
|
||||||
|
|
||||||
|
|
||||||
|
class RegistryManifestFetcher:
    """Fetches manifests from container registries.

    Talks to the registry HTTP API directly (token discovery through the
    WWW-Authenticate challenge, optional credentials from the Docker config,
    multi-arch manifest list resolution) to learn layer sizes before a pull.
    Every failure is reported by returning None so callers can treat the
    manifest as optional.
    """

    # Upper bound in seconds for a single registry/token request. The manifest
    # is only a progress hint, so an unresponsive registry must not hold up the
    # pull for aiohttp's default total timeout of five minutes.
    _REQUEST_TIMEOUT_SECONDS = 15

    def __init__(self, coresys: CoreSys) -> None:
        """Initialize the fetcher."""
        self.coresys = coresys
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT_SECONDS)
            )
        return self._session

    async def close(self) -> None:
        """Close the session."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _get_credentials(self, registry: str) -> tuple[str, str] | None:
        """Get credentials for registry from Docker config.

        Returns (username, password) tuple or None if no credentials.
        """
        registries = self.coresys.docker.config.registries

        # Docker Hub credentials are stored under the DOCKER_HUB key rather
        # than under the API hostname used for requests.
        config_key = DOCKER_HUB if registry == DEFAULT_REGISTRY else registry
        if config_key not in registries:
            return None

        creds = registries[config_key]
        return creds.get("username"), creds.get("password")

    @staticmethod
    def _parse_bearer_challenge(www_auth: str) -> dict[str, str] | None:
        """Parse a ``Bearer`` WWW-Authenticate challenge into its parameters.

        Returns the key="value" parameters as a dict, or None when the header
        does not use the Bearer scheme. The scheme is compared
        case-insensitively because HTTP auth scheme names are
        case-insensitive.
        """
        scheme, _, rest = www_auth.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        return {
            found.group(1): found.group(2)
            for found in re.finditer(r'(\w+)="([^"]*)"', rest)
        }

    @staticmethod
    def _split_platform(platform: str | None) -> tuple[str, str, str | None]:
        """Split "os/arch[/variant]" into its parts.

        Falls back to linux/amd64 with no variant when platform is missing or
        does not contain at least an os and an architecture.
        """
        target_os, target_arch, target_variant = "linux", "amd64", None
        if platform:
            parts = platform.split("/")
            if len(parts) >= 2:
                target_os, target_arch = parts[0], parts[1]
                if len(parts) >= 3 and parts[2]:
                    target_variant = parts[2]
        return target_os, target_arch, target_variant

    @staticmethod
    def _select_platform_manifest(
        manifests: list[dict], platform: str | None
    ) -> dict | None:
        """Pick the manifest list entry matching the requested platform.

        An entry must match os and architecture. When the platform names a
        variant (e.g. "linux/arm/v7") an entry with the same variant wins;
        otherwise the first os/architecture match is used. Returns None when
        nothing matches os and architecture.
        """
        target_os, target_arch, target_variant = (
            RegistryManifestFetcher._split_platform(platform)
        )

        first_arch_match = None
        for entry in manifests:
            plat = entry.get("platform") or {}
            if plat.get("os") != target_os:
                continue
            if plat.get("architecture") != target_arch:
                continue
            if target_variant is None or plat.get("variant") == target_variant:
                return entry
            if first_arch_match is None:
                first_arch_match = entry

        return first_arch_match

    @staticmethod
    def _is_manifest_list(manifest: dict) -> bool:
        """Return True if the document is a manifest list / image index."""
        media_type = manifest.get("mediaType") or ""
        if "list" in media_type or "index" in media_type:
            return True
        # mediaType may be absent from the body; an index is still
        # recognisable by carrying "manifests" instead of "layers".
        return "manifests" in manifest and "layers" not in manifest

    async def _get_auth_token(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
    ) -> str | None:
        """Get authentication token for registry.

        Uses the WWW-Authenticate header from a 401 response to discover
        the token endpoint, then requests a token with appropriate scope.

        Returns the token, or None when no auth is required or the token
        could not be obtained.
        """
        # First, make an unauthenticated request to get WWW-Authenticate header
        manifest_url = f"https://{registry}/v2/{repository}/manifests/latest"

        try:
            async with session.get(manifest_url) as resp:
                if resp.status == 200:
                    # No auth required
                    return None

                if resp.status != 401:
                    _LOGGER.warning(
                        "Unexpected status %d from registry %s", resp.status, registry
                    )
                    return None

                www_auth = resp.headers.get("WWW-Authenticate", "")
        except (aiohttp.ClientError, TimeoutError) as err:
            _LOGGER.warning("Failed to connect to registry %s: %s", registry, err)
            return None

        # Parse WWW-Authenticate: Bearer realm="...",service="...",scope="..."
        params = self._parse_bearer_challenge(www_auth)
        if params is None:
            _LOGGER.warning("Unsupported auth type from %s: %s", registry, www_auth)
            return None

        realm = params.get("realm")
        if not realm:
            _LOGGER.warning("No realm in WWW-Authenticate from %s", registry)
            return None

        # Pass the query as params so values are URL-encoded and merged with
        # any query string the realm URL already carries.
        query = {"scope": f"repository:{repository}:pull"}
        service = params.get("service")
        if service:
            query["service"] = service

        # Check for credentials
        auth = None
        credentials = self._get_credentials(registry)
        if credentials:
            username, password = credentials
            if username and password:
                auth = aiohttp.BasicAuth(username, password)
                _LOGGER.debug("Using credentials for %s", registry)

        try:
            async with session.get(realm, params=query, auth=auth) as resp:
                if resp.status != 200:
                    _LOGGER.warning(
                        "Failed to get token from %s: %d", realm, resp.status
                    )
                    return None

                data = await resp.json()
        except (aiohttp.ClientError, TimeoutError, ValueError) as err:
            # ValueError covers a body that is not valid JSON
            _LOGGER.warning("Failed to get auth token: %s", err)
            return None

        if not isinstance(data, dict):
            _LOGGER.warning("Unexpected token response from %s", realm)
            return None

        return data.get("token") or data.get("access_token")

    async def _request_manifest(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
        reference: str,
        token: str | None,
    ) -> dict | None:
        """Issue one manifest GET and return the decoded JSON object.

        reference is a tag or a digest. Returns None on any HTTP, network,
        timeout or decoding failure.
        """
        manifest_url = f"https://{registry}/v2/{repository}/manifests/{reference}"

        headers = {"Accept": ", ".join(MANIFEST_MEDIA_TYPES)}
        if token:
            headers["Authorization"] = f"Bearer {token}"

        try:
            async with session.get(manifest_url, headers=headers) as resp:
                if resp.status != 200:
                    _LOGGER.warning(
                        "Failed to fetch manifest for %s/%s:%s - %d",
                        registry,
                        repository,
                        reference,
                        resp.status,
                    )
                    return None

                manifest = await resp.json()
        except (aiohttp.ClientError, TimeoutError, ValueError) as err:
            # ValueError covers a body that is not valid JSON
            _LOGGER.warning("Failed to fetch manifest: %s", err)
            return None

        if not isinstance(manifest, dict):
            _LOGGER.warning(
                "Unexpected manifest document for %s/%s:%s",
                registry,
                repository,
                reference,
            )
            return None

        return manifest

    async def _fetch_manifest(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
        reference: str,
        token: str | None,
        platform: str | None = None,
    ) -> dict | None:
        """Fetch manifest from registry.

        If the manifest is a manifest list (multi-arch), fetches the
        platform-specific manifest. Only one level of indirection is
        followed, so a registry that keeps answering with lists cannot cause
        unbounded requests.
        """
        manifest = await self._request_manifest(
            session, registry, repository, reference, token
        )
        if manifest is None:
            return None

        # Plain image manifest: nothing to resolve
        if not self._is_manifest_list(manifest):
            return manifest

        manifests = manifest.get("manifests", [])
        if not manifests:
            _LOGGER.warning("Empty manifest list for %s/%s", registry, repository)
            return None

        # Find matching platform
        platform_manifest = self._select_platform_manifest(manifests, platform)
        if not platform_manifest:
            # Fall back to first manifest
            target_os, target_arch, _ = self._split_platform(platform)
            _LOGGER.debug(
                "Platform %s/%s not found, using first manifest",
                target_os,
                target_arch,
            )
            platform_manifest = manifests[0]

        digest = platform_manifest.get("digest")
        if not digest:
            _LOGGER.warning(
                "Manifest list entry without digest for %s/%s", registry, repository
            )
            return None

        # Fetch the platform-specific manifest
        resolved = await self._request_manifest(
            session, registry, repository, digest, token
        )
        if resolved is not None and self._is_manifest_list(resolved):
            _LOGGER.warning(
                "Nested manifest list for %s/%s not supported", registry, repository
            )
            return None

        return resolved

    async def get_manifest(
        self,
        image: str,
        tag: str,
        platform: str | None = None,
    ) -> ImageManifest | None:
        """Fetch manifest and extract layer sizes.

        Args:
            image: Image name (e.g., "ghcr.io/home-assistant/home-assistant")
            tag: Image tag (e.g., "2025.1.0")
            platform: Target platform (e.g., "linux/amd64")

        Returns:
            ImageManifest with layer sizes, or None if fetch failed.

        """
        registry, repository, tag = parse_image_reference(image, tag)

        _LOGGER.debug(
            "Fetching manifest for %s/%s:%s (platform=%s)",
            registry,
            repository,
            tag,
            platform,
        )

        session = await self._get_session()

        # Get auth token
        token = await self._get_auth_token(session, registry, repository)

        # Fetch manifest
        manifest = await self._fetch_manifest(
            session, registry, repository, tag, token, platform
        )

        if not manifest:
            return None

        # Extract layer information
        layers = manifest.get("layers", [])
        if not layers:
            _LOGGER.warning(
                "No layers in manifest for %s/%s:%s", registry, repository, tag
            )
            return None

        layer_sizes: dict[str, int] = {}
        total_size = 0

        for layer in layers:
            digest = layer.get("digest", "")
            size = layer.get("size", 0)
            if digest and size:
                # Store by short digest (first 12 chars after sha256:)
                short_digest = (
                    digest.split(":")[1][:12] if ":" in digest else digest[:12]
                )
                # A manifest may list the same blob more than once; count it
                # once so total_size agrees with the sizes stored per layer.
                if short_digest not in layer_sizes:
                    total_size += size
                layer_sizes[short_digest] = size

        digest = manifest.get("config", {}).get("digest", "")

        _LOGGER.debug(
            "Manifest for %s/%s:%s - %d layers, %d bytes total",
            registry,
            repository,
            tag,
            len(layer_sizes),
            total_size,
        )

        return ImageManifest(
            digest=digest,
            total_size=total_size,
            layers=layer_sizes,
        )
|
||||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, cast
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .manager import PullLogEntry
|
from .manager import PullLogEntry
|
||||||
|
from .manifest import ImageManifest
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -109,23 +110,43 @@ class LayerProgress:
|
|||||||
class ImagePullProgress:
|
class ImagePullProgress:
|
||||||
"""Track overall progress of pulling an image.
|
"""Track overall progress of pulling an image.
|
||||||
|
|
||||||
Uses count-based progress where each layer contributes equally regardless of size.
|
When manifest layer sizes are provided, uses size-weighted progress where
|
||||||
This avoids progress regression when large layers are discovered late due to
|
each layer contributes proportionally to its size. This gives accurate
|
||||||
Docker's rate-limiting of concurrent downloads.
|
progress based on actual bytes to download.
|
||||||
|
|
||||||
Progress is only reported after the first "Downloading" event, since Docker
|
When manifest is not available, falls back to count-based progress where
|
||||||
sends "Already exists" and "Pulling fs layer" events before we know the full
|
each layer contributes equally.
|
||||||
layer count.
|
|
||||||
|
Layers that already exist locally are excluded from the progress calculation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
layers: dict[str, LayerProgress] = field(default_factory=dict)
|
layers: dict[str, LayerProgress] = field(default_factory=dict)
|
||||||
_last_reported_progress: float = field(default=0.0, repr=False)
|
_last_reported_progress: float = field(default=0.0, repr=False)
|
||||||
_seen_downloading: bool = field(default=False, repr=False)
|
_seen_downloading: bool = field(default=False, repr=False)
|
||||||
|
_manifest_layer_sizes: dict[str, int] = field(default_factory=dict, repr=False)
|
||||||
|
_total_manifest_size: int = field(default=0, repr=False)
|
||||||
|
|
||||||
|
def set_manifest(self, manifest: ImageManifest) -> None:
    """Record the manifest's layer sizes for size-based progress.

    Stores a copy of the manifest's layer mapping and its total size.
    Should be called before processing pull events.
    """
    size_by_layer = dict(manifest.layers)
    self._manifest_layer_sizes = size_by_layer
    self._total_manifest_size = manifest.total_size

    layer_total = len(self._manifest_layer_sizes)
    _LOGGER.debug(
        "Manifest set: %d layers, %d bytes total",
        layer_total,
        self._total_manifest_size,
    )
|
||||||
|
|
||||||
def get_or_create_layer(self, layer_id: str) -> LayerProgress:
|
def get_or_create_layer(self, layer_id: str) -> LayerProgress:
|
||||||
"""Get existing layer or create new one."""
|
"""Get existing layer or create new one."""
|
||||||
if layer_id not in self.layers:
|
if layer_id not in self.layers:
|
||||||
self.layers[layer_id] = LayerProgress(layer_id=layer_id)
|
# If we have manifest sizes, pre-populate the layer's total_size
|
||||||
|
manifest_size = self._manifest_layer_sizes.get(layer_id, 0)
|
||||||
|
self.layers[layer_id] = LayerProgress(
|
||||||
|
layer_id=layer_id, total_size=manifest_size
|
||||||
|
)
|
||||||
return self.layers[layer_id]
|
return self.layers[layer_id]
|
||||||
|
|
||||||
def process_event(self, entry: PullLogEntry) -> None:
|
def process_event(self, entry: PullLogEntry) -> None:
|
||||||
@@ -237,8 +258,13 @@ class ImagePullProgress:
|
|||||||
def calculate_progress(self) -> float:
|
def calculate_progress(self) -> float:
|
||||||
"""Calculate overall progress 0-100.
|
"""Calculate overall progress 0-100.
|
||||||
|
|
||||||
Uses count-based progress where each layer that needs pulling contributes
|
When manifest layer sizes are available, uses size-weighted progress
|
||||||
equally. Layers that already exist locally are excluded from the calculation.
|
where each layer contributes proportionally to its size.
|
||||||
|
|
||||||
|
When manifest is not available, falls back to count-based progress
|
||||||
|
where each layer contributes equally.
|
||||||
|
|
||||||
|
Layers that already exist locally are excluded from the calculation.
|
||||||
|
|
||||||
Returns 0 until we've seen the first "Downloading" event, since Docker
|
Returns 0 until we've seen the first "Downloading" event, since Docker
|
||||||
reports "Already exists" and "Pulling fs layer" events before we know
|
reports "Already exists" and "Pulling fs layer" events before we know
|
||||||
@@ -258,10 +284,39 @@ class ImagePullProgress:
|
|||||||
# All layers already exist, nothing to download
|
# All layers already exist, nothing to download
|
||||||
return 100.0
|
return 100.0
|
||||||
|
|
||||||
# Each layer contributes equally: sum of layer progresses / total layers
|
# Use size-weighted progress if manifest sizes are available
|
||||||
|
if self._manifest_layer_sizes:
|
||||||
|
return self._calculate_size_weighted_progress(layers_to_pull)
|
||||||
|
|
||||||
|
# Fall back to count-based progress
|
||||||
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
|
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
|
||||||
return total_progress / len(layers_to_pull)
|
return total_progress / len(layers_to_pull)
|
||||||
|
|
||||||
|
def _calculate_size_weighted_progress(
|
||||||
|
self, layers_to_pull: list[LayerProgress]
|
||||||
|
) -> float:
|
||||||
|
"""Calculate size-weighted progress.
|
||||||
|
|
||||||
|
Each layer contributes to progress proportionally to its size.
|
||||||
|
Progress = sum(layer_progress * layer_size) / total_size
|
||||||
|
"""
|
||||||
|
# Calculate total size of layers that need pulling
|
||||||
|
total_size = sum(layer.total_size for layer in layers_to_pull)
|
||||||
|
|
||||||
|
if total_size == 0:
|
||||||
|
# No size info available, fall back to count-based
|
||||||
|
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
|
||||||
|
return total_progress / len(layers_to_pull)
|
||||||
|
|
||||||
|
# Weight each layer's progress by its size
|
||||||
|
weighted_progress = 0.0
|
||||||
|
for layer in layers_to_pull:
|
||||||
|
if layer.total_size > 0:
|
||||||
|
layer_weight = layer.total_size / total_size
|
||||||
|
weighted_progress += layer.calculate_progress() * layer_weight
|
||||||
|
|
||||||
|
return weighted_progress
|
||||||
|
|
||||||
def get_stage(self) -> str | None:
|
def get_stage(self) -> str | None:
|
||||||
"""Get current stage based on layer states."""
|
"""Get current stage based on layer states."""
|
||||||
if not self.layers:
|
if not self.layers:
|
||||||
|
|||||||
@@ -154,6 +154,9 @@ async def docker() -> DockerAPI:
|
|||||||
docker_obj.info.storage = "overlay2"
|
docker_obj.info.storage = "overlay2"
|
||||||
docker_obj.info.version = AwesomeVersion("1.0.0")
|
docker_obj.info.version = AwesomeVersion("1.0.0")
|
||||||
|
|
||||||
|
# Mock manifest fetcher to return None (falls back to count-based progress)
|
||||||
|
docker_obj._manifest_fetcher.get_manifest = AsyncMock(return_value=None)
|
||||||
|
|
||||||
yield docker_obj
|
yield docker_obj
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
164
tests/docker/test_manifest.py
Normal file
164
tests/docker/test_manifest.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Tests for registry manifest fetcher."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from supervisor.docker.manifest import (
|
||||||
|
DEFAULT_REGISTRY,
|
||||||
|
ImageManifest,
|
||||||
|
RegistryManifestFetcher,
|
||||||
|
parse_image_reference,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseImageReference:
    """Exercise parse_image_reference across registry naming styles."""

    def test_ghcr_io_image(self):
        """An explicit ghcr.io host is split from the repository path."""
        parsed = parse_image_reference(
            "ghcr.io/home-assistant/home-assistant", "2025.1.0"
        )
        assert parsed == ("ghcr.io", "home-assistant/home-assistant", "2025.1.0")

    def test_docker_hub_with_org(self):
        """A host-less org/image name resolves to the default registry."""
        parsed = parse_image_reference("homeassistant/home-assistant", "latest")
        assert parsed == (DEFAULT_REGISTRY, "homeassistant/home-assistant", "latest")

    def test_docker_hub_official_image(self):
        """A bare image name gains the library/ prefix."""
        parsed = parse_image_reference("alpine", "3.18")
        assert parsed == (DEFAULT_REGISTRY, "library/alpine", "3.18")

    def test_gcr_io_image(self):
        """An explicit gcr.io host is split from the repository path."""
        parsed = parse_image_reference("gcr.io/project/image", "v1")
        assert parsed == ("gcr.io", "project/image", "v1")
|
||||||
|
|
||||||
|
|
||||||
|
class TestImageManifest:
    """Checks for the ImageManifest dataclass."""

    def test_layer_count(self):
        """layer_count reports the number of entries in layers."""
        layer_map = {"layer1": 500, "layer2": 500}
        manifest = ImageManifest(
            digest="sha256:abc", total_size=1000, layers=layer_map
        )
        assert manifest.layer_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegistryManifestFetcher:
    """Checks for the RegistryManifestFetcher class."""

    @pytest.fixture
    def mock_coresys(self):
        """Provide a mocked coresys with an empty registry config."""
        mocked = MagicMock()
        mocked.docker.config.registries = {}
        return mocked

    @pytest.fixture
    def fetcher(self, mock_coresys):
        """Provide a fetcher bound to the mocked coresys."""
        return RegistryManifestFetcher(mock_coresys)

    async def test_get_manifest_success(self, fetcher):
        """A fetched manifest is turned into sizes keyed by short digest."""
        first_layer = {"digest": "sha256:layer1abc123def456789012", "size": 1000}
        second_layer = {"digest": "sha256:layer2def456abc789012345", "size": 2000}
        manifest_data = {
            "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
            "config": {"digest": "sha256:abc123"},
            "layers": [first_layer, second_layer],
        }

        # Replace the network-facing internals rather than the session itself
        token_stub = AsyncMock(return_value="test-token")
        manifest_stub = AsyncMock(return_value=manifest_data)
        with (
            patch.object(fetcher, "_get_auth_token", new=token_stub),
            patch.object(fetcher, "_fetch_manifest", new=manifest_stub),
            patch.object(fetcher, "_get_session", new=AsyncMock()),
        ):
            result = await fetcher.get_manifest(
                "test.io/org/image", "v1.0", platform="linux/amd64"
            )

        assert result is not None
        assert result.total_size == 3000
        assert result.layer_count == 2
        # Keys are the first 12 chars after "sha256:"
        assert "layer1abc123" in result.layers
        assert result.layers["layer1abc123"] == 1000

    async def test_get_manifest_returns_none_on_failure(self, fetcher):
        """get_manifest yields None when no manifest could be fetched."""
        token_stub = AsyncMock(return_value="test-token")
        manifest_stub = AsyncMock(return_value=None)
        with (
            patch.object(fetcher, "_get_auth_token", new=token_stub),
            patch.object(fetcher, "_fetch_manifest", new=manifest_stub),
            patch.object(fetcher, "_get_session", new=AsyncMock()),
        ):
            result = await fetcher.get_manifest(
                "test.io/org/image", "v1.0", platform="linux/amd64"
            )

        assert result is None

    async def test_close_session(self, fetcher):
        """close() closes an open session and drops the reference."""
        open_session = AsyncMock()
        open_session.closed = False
        open_session.close = AsyncMock()
        fetcher._session = open_session

        await fetcher.close()

        open_session.close.assert_called_once()
        assert fetcher._session is None

    def test_get_credentials_docker_hub(self, mock_coresys, fetcher):
        """Docker Hub credentials are found for the default registry host."""
        hub_entry = {"username": "user", "password": "pass"}
        mock_coresys.docker.config.registries = {"hub.docker.com": hub_entry}

        assert fetcher._get_credentials(DEFAULT_REGISTRY) == ("user", "pass")

    def test_get_credentials_custom_registry(self, mock_coresys, fetcher):
        """Credentials stored under a registry's own hostname are returned."""
        ghcr_entry = {"username": "user", "password": "token"}
        mock_coresys.docker.config.registries = {"ghcr.io": ghcr_entry}

        assert fetcher._get_credentials("ghcr.io") == ("user", "token")

    def test_get_credentials_not_found(self, mock_coresys, fetcher):
        """None is returned when the registry has no stored credentials."""
        mock_coresys.docker.config.registries = {}

        assert fetcher._get_credentials("unknown.io") is None
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from supervisor.docker.manager import PullLogEntry, PullProgressDetail
|
from supervisor.docker.manager import PullLogEntry, PullProgressDetail
|
||||||
|
from supervisor.docker.manifest import ImageManifest
|
||||||
from supervisor.docker.pull_progress import (
|
from supervisor.docker.pull_progress import (
|
||||||
DOWNLOAD_WEIGHT,
|
DOWNLOAD_WEIGHT,
|
||||||
EXTRACT_WEIGHT,
|
EXTRACT_WEIGHT,
|
||||||
@@ -784,3 +785,218 @@ class TestImagePullProgress:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert progress.calculate_progress() == 100.0
|
assert progress.calculate_progress() == 100.0
|
||||||
|
|
||||||
|
def test_size_weighted_progress_with_manifest(self):
    """Check that layers contribute to progress in proportion to manifest size."""
    small_id, large_id = "small123456", "large123456"

    # Two layers of very different size: 1KB (~1% of total), 100KB (~99%).
    manifest = ImageManifest(
        digest="sha256:test",
        total_size=101000,
        layers={small_id: 1000, large_id: 100000},
    )

    tracker = ImagePullProgress()
    tracker.set_manifest(manifest)

    def feed(layer_id, status, **detail):
        """Push a single pull log entry for ``layer_id`` into the tracker."""
        tracker.process_event(
            PullLogEntry(
                job_id="test",
                id=layer_id,
                status=status,
                progress_detail=PullProgressDetail(**detail),
            )
        )

    # Both layers are announced, the small one first.
    for layer_id in (small_id, large_id):
        feed(layer_id, "Pulling fs layer")

    # Small layer finishes its download: 70% of a layer that is ~1% of the
    # total size, i.e. roughly 0.7% overall.
    feed(small_id, "Downloading", current=1000, total=1000)
    assert tracker.calculate_progress() == pytest.approx(0.69, rel=0.1)

    # Large layer reaches 1% of its download: 1% * 70% * ~99% weight adds
    # roughly another 0.7%, for about 1.4% in total.
    feed(large_id, "Downloading", current=1000, total=100000)
    after_large_start = tracker.calculate_progress()
    assert after_large_start > 0.7  # More than the small layer alone
    assert after_large_start < 5.0  # But still only a small fraction

    # Finishing both layers must land on exactly 100%.
    for layer_id in (small_id, large_id):
        feed(layer_id, "Pull complete")

    assert tracker.calculate_progress() == 100.0
|
||||||
|
|
||||||
|
def test_size_weighted_excludes_already_exists(self):
    """Check that locally cached layers do not count towards size-weighted progress."""
    cached_id, first_id, second_id = "cached12345", "layer1_1234", "layer2_1234"

    # Three layers in the manifest; the first one will be reported as cached.
    manifest = ImageManifest(
        digest="sha256:test",
        total_size=200000,
        layers={
            cached_id: 100000,  # Reported as existing - must not be weighted
            first_id: 50000,  # Has to be pulled
            second_id: 50000,  # Has to be pulled
        },
    )

    tracker = ImagePullProgress()
    tracker.set_manifest(manifest)

    def feed(layer_id, status, **detail):
        """Push a single pull log entry for ``layer_id`` into the tracker."""
        tracker.process_event(
            PullLogEntry(
                job_id="test",
                id=layer_id,
                status=status,
                progress_detail=PullProgressDetail(**detail),
            )
        )

    # The cached layer is reported first, then the two that need pulling.
    feed(cached_id, "Already exists")
    feed(first_id, "Pulling fs layer")
    feed(second_id, "Pulling fs layer")

    # First layer is half downloaded. It is 50KB of the 100KB that actually
    # needs pulling, and 50% download equals 35% layer progress (70% * 50%),
    # so the size-weighted total is 50% * 35% = 17.5%.
    feed(first_id, "Downloading", current=25000, total=50000)
    assert tracker.calculate_progress() == pytest.approx(17.5)

    # First layer done, second untouched: 50% * 100% + 50% * 0% = 50%.
    feed(first_id, "Pull complete")
    assert tracker.calculate_progress() == pytest.approx(50.0)
|
||||||
|
|
||||||
|
def test_fallback_to_count_based_without_manifest(self):
    """Check that every layer gets equal weight when no manifest was set."""
    # Deliberately no set_manifest() call here.
    tracker = ImagePullProgress()

    def feed(layer_id, status, **detail):
        """Push a single pull log entry for ``layer_id`` into the tracker."""
        tracker.process_event(
            PullLogEntry(
                job_id="test",
                id=layer_id,
                status=status,
                progress_detail=PullProgressDetail(**detail),
            )
        )

    # Two layers whose real sizes differ by several orders of magnitude.
    for layer_id in ("small", "large"):
        feed(layer_id, "Pulling fs layer")

    # The 1KB layer downloads and completes.
    feed("small", "Downloading", current=1000, total=1000)
    feed("small", "Pull complete")

    # The 100MB layer is only 1% downloaded.
    feed("large", "Downloading", current=1000000, total=100000000)

    # Count-based weighting gives each layer 50%:
    #   small: 100% * 50%               = 50%
    #   large: 0.7% (1% * 70%) * 50%    = 0.35%
    # for a total of roughly 50.35%.
    assert tracker.calculate_progress() == pytest.approx(50.35, rel=0.01)
|
||||||
|
|||||||
Reference in New Issue
Block a user