server/internal/client/ollama: persist through chunk download errors (#9923)
Parent: 00ebda8cc4
Commit: c794fef2f2
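Before the diff itself, a sketch of what the change means for a caller: a failed chunksum stream no longer aborts the pull outright; remaining downloads run to completion, fully downloaded layers are detected as cached on later attempts, and Pull reports any shortfall as ErrIncomplete. The snippet below is illustrative only and not part of the commit; the import path is an assumption (the package is internal and is normally exercised from inside the module, as the Example added in the test file further down does), while Registry, DefaultRegistry, WithTrace, Trace, Layer, ErrCached, and ErrIncomplete all appear in the diff.

package main

import (
	"context"
	"errors"
	"log"

	// Assumed import path, for illustration only; this is an internal package.
	ollama "github.com/ollama/ollama/server/internal/client/ollama"
)

// pullReportingFailures pulls a model, logging per-layer errors as they
// happen and distinguishing an incomplete pull from other failures.
func pullReportingFailures(ctx context.Context, r *ollama.Registry, name string) error {
	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil && !errors.Is(err, ollama.ErrCached) {
				// An error here (for example an interrupted chunksum
				// stream) no longer aborts the pull; remaining
				// downloads keep going.
				log.Printf("layer %v: %v", l.Digest, err)
			}
		},
	})

	err := r.Pull(ctx, name)
	if errors.Is(err, ollama.ErrIncomplete) {
		// Some bytes never arrived. Layers that did complete are
		// detected as cached on the next attempt, so retrying is cheap.
		log.Printf("pull incomplete, retry later: %v", err)
	}
	return err
}

func main() {
	r, err := ollama.DefaultRegistry()
	if err != nil {
		log.Fatal(err)
	}
	if err := pullReportingFailures(context.Background(), r, "model"); err != nil {
		log.Fatal(err)
	}
}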
@@ -59,6 +59,11 @@ var (
 	// ErrCached is passed to [Trace.PushUpdate] when a layer already
 	// exists. It is a non-fatal error and is never returned by [Registry.Push].
 	ErrCached = errors.New("cached")
+
+	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
+	// incomplete due to one or more layer download failures. Users that
+	// want specific errors should use [WithTrace].
+	ErrIncomplete = errors.New("incomplete")
 )

 // Defaults
@@ -271,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {

 func UserAgent() string {
 	buildinfo, _ := debug.ReadBuildInfo()
+
+	version := buildinfo.Main.Version
+	if version == "(devel)" {
+		// When using `go run .` the version is "(devel)". This is seen
+		// as an invalid version by ollama.com and so it defaults to
+		// "needs upgrade" for some requests, such as pulls. These
+		// checks can be skipped by using the special version "v0.0.0",
+		// so we set it to that here.
+		version = "v0.0.0"
+	}
+
 	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
-		buildinfo.Main.Version,
+		version,
 		runtime.GOARCH,
 		runtime.GOOS,
 		runtime.Version(),
@@ -418,13 +434,14 @@ func canRetry(err error) bool {
 //
 // It always calls update with a nil error.
 type trackingReader struct {
+	l *Layer
 	r io.Reader
-	n *atomic.Int64
+	update func(l *Layer, n int64, err error)
 }

 func (r *trackingReader) Read(p []byte) (n int, err error) {
 	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
+	r.update(r.l, int64(n), nil)
 	return
 }

@@ -462,16 +479,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {

 	// Send initial layer trace events to allow clients to have an
 	// understanding of work to be done before work starts.
+	var expected int64
 	t := traceFromContext(ctx)
 	for _, l := range layers {
 		t.update(l, 0, nil)
+		expected += l.Size
 	}

+	var total atomic.Int64
 	var g errgroup.Group
 	g.SetLimit(r.maxStreams())
 	for _, l := range layers {
 		info, err := c.Get(l.Digest)
 		if err == nil && info.Size == l.Size {
+			total.Add(l.Size)
 			t.update(l, l.Size, ErrCached)
 			continue
 		}
@@ -484,21 +505,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		// TODO(bmizerany): fix this unbounded use of defer
 		defer chunked.Close()

-		var progress atomic.Int64
 		for cs, err := range r.chunksums(ctx, name, l) {
 			if err != nil {
-				// Bad chunksums response, update tracing
-				// clients and then bail.
-				t.update(l, progress.Load(), err)
-				return err
+				// Chunksum stream was interrupted, so tell
+				// trace about it, and let in-flight chunk
+				// downloads finish. Once they finish, return
+				// ErrIncomplete, which is triggered by the
+				// fact that the total bytes received is less
+				// than the expected bytes.
+				t.update(l, 0, err)
+				break
 			}

 			g.Go(func() (err error) {
 				defer func() {
-					if err != nil {
+					if err == nil || errors.Is(err, ErrCached) {
+						total.Add(cs.Chunk.Size())
+					} else {
 						err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
 					}
-					t.update(l, progress.Load(), err)
 				}()

 				req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -522,7 +547,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 				// download rate since it knows better than a
 				// client that is measuring rate based on
 				// wall-clock time-since-last-update.
-				body := &trackingReader{r: res.Body, n: &progress}
+				body := &trackingReader{l: l, r: res.Body, update: t.update}

 				return chunked.Put(cs.Chunk, cs.Digest, body)
 			})
@@ -531,6 +556,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 	if err := g.Wait(); err != nil {
 		return err
 	}
+	if total.Load() != expected {
+		return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
+	}

 	md := blob.DigestFromBytes(m.Data)
 	if err := blob.PutBytes(c, md, m.Data); err != nil {
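The accounting added in the hunks above reduces to one pattern: sum the expected layer sizes up front, count the bytes of every chunk that is actually stored, and compare the two after all goroutines finish. In Pull the chunk list is streamed and can end early on error, which is exactly the shortfall the final comparison catches. A standalone sketch of that pattern with illustrative names (fetchAll, errIncomplete, and fetch are not from the package):

package main

import (
	"errors"
	"fmt"
	"sync/atomic"

	"golang.org/x/sync/errgroup"
)

var errIncomplete = errors.New("incomplete")

// fetchAll runs every download to completion (an errgroup.Group without a
// context does not cancel its siblings on failure) and then checks whether
// the bytes received add up to the bytes expected.
func fetchAll(sizes []int64, fetch func(i int) (int64, error)) error {
	var expected int64
	for _, s := range sizes {
		expected += s
	}

	var received atomic.Int64
	var g errgroup.Group
	for i := range sizes {
		g.Go(func() error {
			n, err := fetch(i)
			if err != nil {
				return fmt.Errorf("part %d: %w", i, err)
			}
			received.Add(n)
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return err
	}
	if received.Load() != expected {
		return fmt.Errorf("%w: received %d/%d", errIncomplete, received.Load(), expected)
	}
	return nil
}

func main() {
	sizes := []int64{3, 2}
	err := fetchAll(sizes, func(i int) (int64, error) {
		return sizes[i], nil // pretend every part downloads fully
	})
	fmt.Println(err) // <nil>
}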
@@ -757,15 +785,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 		}
 		blobURL := res.Header.Get("Content-Location")

-		var size int64
 		s := bufio.NewScanner(res.Body)
 		s.Split(bufio.ScanWords)
 		for {
 			if !s.Scan() {
 				if s.Err() != nil {
 					yield(chunksum{}, s.Err())
-				} else if size != l.Size {
-					yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
 				}
 				return
 			}
@@ -789,12 +814,6 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
 				return
 			}

-			size += chunk.Size()
-			if size > l.Size {
-				yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
-				return
-			}
-
 			cs := chunksum{
 				URL:   blobURL,
 				Chunk: chunk,
@@ -25,6 +25,28 @@ import (
 	"github.com/ollama/ollama/server/internal/testutil"
 )

+func ExampleRegistry_cancelOnFirstError() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	ctx = WithTrace(ctx, &Trace{
+		Update: func(l *Layer, n int64, err error) {
+			if err != nil {
+				// Discontinue pulling layers if there is an
+				// error instead of continuing to pull more
+				// data.
+				cancel()
+			}
+		},
+	})
+
+	var r Registry
+	if err := r.Pull(ctx, "model"); err != nil {
+		// panic for demo purposes
+		panic(err)
+	}
+}
+
 func TestManifestMarshalJSON(t *testing.T) {
 	// All manifests should contain an "empty" config object.
 	var m Manifest
@@ -813,8 +835,13 @@ func TestPullChunksums(t *testing.T) {
 	)
 	err := rc.Pull(ctx, "test")
 	check(err)
-	if !slices.Equal(reads, []int64{0, 3, 5}) {
-		t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
+	wantReads := []int64{
+		0, // initial signaling of layer pull starting
+		3, // first chunk read
+		2, // second chunk read
+	}
+	if !slices.Equal(reads, wantReads) {
+		t.Errorf("reads = %v; want %v", reads, wantReads)
 	}

 	mw, err := rc.Resolve(t.Context(), "test")
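The expected reads change from cumulative totals {0, 3, 5} to per-read deltas {0, 3, 2} because trackingReader now forwards each Read's byte count to Trace.Update instead of maintaining a running counter. Consumers therefore accumulate, as the progress[l] += n change further down does. A self-contained sketch of such an accumulator, with a stand-in layer type since the snippet lives outside the package (layerInfo, progressTracker, and the digest value are illustrative):

package main

import (
	"fmt"
	"sync"
)

// layerInfo stands in for the package's *Layer in this sketch; only an
// identity and a size are needed to accumulate progress.
type layerInfo struct {
	digest string
	size   int64
}

// progressTracker accumulates the per-update byte deltas that Trace.Update
// now reports, rather than treating each value as a running total.
type progressTracker struct {
	mu       sync.Mutex
	expected int64
	received map[string]int64
}

func (p *progressTracker) update(l *layerInfo, n int64, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.received == nil {
		p.received = make(map[string]int64)
	}
	if _, ok := p.received[l.digest]; !ok {
		// First event for this layer (the n == 0 announcement that Pull
		// sends before any work starts): count its size as expected.
		p.expected += l.size
	}
	p.received[l.digest] += n // n is a delta, not a running total
}

func (p *progressTracker) percent() float64 {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.expected == 0 {
		return 0
	}
	var got int64
	for _, n := range p.received {
		got += n
	}
	return 100 * float64(got) / float64(p.expected)
}

func main() {
	var p progressTracker
	l := &layerInfo{digest: "sha256:...", size: 5}
	p.update(l, 0, nil) // initial announcement
	p.update(l, 3, nil) // first chunk
	p.update(l, 2, nil) // second chunk
	fmt.Printf("%.0f%% complete\n", p.percent()) // 100% complete
}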
@@ -200,7 +200,7 @@ type params struct {
 	//
 	// Unfortunately, this API was designed to be a bit awkward. Stream is
 	// defined to default to true if not present, so we need a way to check
-	// if the client decisively it to false. So, we use a pointer to a
+	// if the client decisively set it to false. So, we use a pointer to a
 	// bool. Gross.
 	//
 	// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	progress := make(map[*ollama.Layer]int64)

 	progressCopy := make(map[*ollama.Layer]int64, len(progress))
-	pushUpdate := func() {
+	flushProgress := func() {
 		defer maybeFlush()

-		// TODO(bmizerany): This scales poorly with more layers due to
-		// needing to flush out them all in one big update. We _could_
-		// just flush on the changed ones, or just track the whole
-		// download. Needs more thought. This is fine for now.
+		// TODO(bmizerany): Flushing every layer in one update doesn't
+		// scale well. We could flush only the modified layers or track
+		// the full download. Needs further consideration, though it's
+		// fine for now.
 		mu.Lock()
 		maps.Copy(progressCopy, progress)
 		mu.Unlock()
-		for l, n := range progress {
+		for l, n := range progressCopy {
 			enc.Encode(progressUpdateJSON{
 				Digest:    l.Digest,
 				Total:     l.Size,
@@ -298,19 +298,26 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 			})
 		}
 	}
-	t := time.NewTicker(time.Hour) // "unstarted" timer
+	defer flushProgress()
+
+	t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
 	start := sync.OnceFunc(func() {
-		pushUpdate()
+		flushProgress() // flush initial state
 		t.Reset(100 * time.Millisecond)
 	})
 	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
 		Update: func(l *ollama.Layer, n int64, err error) {
 			if n > 0 {
-				start() // flush initial state
+				// Block flushing progress updates until every
+				// layer is accounted for. Clients depend on a
+				// complete model size to calculate progress
+				// correctly; if they use an incomplete total,
+				// progress indicators would erratically jump
+				// as new layers are registered.
+				start()
 			}
 			mu.Lock()
-			progress[l] = n
+			progress[l] += n
 			mu.Unlock()
 		},
 	})
@@ -323,9 +330,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
 	for {
 		select {
 		case <-t.C:
-			pushUpdate()
+			flushProgress()
 		case err := <-done:
-			pushUpdate()
+			flushProgress()
 			if err != nil {
 				var status string
 				if errors.Is(err, ollama.ErrModelNotFound) {