From 1d99451ad705478c0a22262ad38b5a403b61c291 Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Thu, 17 Apr 2025 12:43:09 -0700 Subject: [PATCH] server/internal/client/ollama: handle some network errors gracefully (#10317) --- server/internal/client/ollama/registry.go | 52 +++++++++++++++---- .../client/ollama/registry_synctest_test.go | 51 ++++++++++++++++++ .../internal/client/ollama/registry_test.go | 4 +- server/internal/registry/server.go | 24 +++++++-- 4 files changed, 115 insertions(+), 16 deletions(-) create mode 100644 server/internal/client/ollama/registry_synctest_test.go diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index 4d00b41e1..18c7b70be 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -223,8 +223,21 @@ type Registry struct { ChunkingThreshold int64 // Mask, if set, is the name used to convert non-fully qualified names - // to fully qualified names. If empty, [DefaultMask] is used. + // to fully qualified names. + // If empty, [DefaultMask] is used. Mask string + + // ReadTimeout is the maximum duration a chunk download may go without a + // progress update; the timer restarts whenever the download reports progress. + // A zero or negative value means there will be no timeout. + ReadTimeout time.Duration +} + +func (r *Registry) readTimeout() time.Duration { + if r.ReadTimeout > 0 { + return r.ReadTimeout + } + return 1<<63 - 1 // no timeout, max int } func (r *Registry) cache() (*blob.DiskCache, error) { @@ -248,8 +261,7 @@ func (r *Registry) parseName(name string) (names.Name, error) { // DefaultRegistry returns a new Registry configured from the environment. The // key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the -// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the -// system's temporary directory. +// value of OLLAMA_REGISTRY_MAXSTREAMS, and ReadTimeout is set to 30 seconds. 
// // It returns an error if any configuration in the environment is invalid. func DefaultRegistry() (*Registry, error) { @@ -263,6 +275,7 @@ func DefaultRegistry() (*Registry, error) { } var rc Registry + rc.ReadTimeout = 30 * time.Second rc.UserAgent = UserAgent() rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) if err != nil { @@ -489,6 +502,12 @@ func (r *Registry) Pull(ctx context.Context, name string) error { for _, l := range layers { var received atomic.Int64 update := func(n int64, err error) { + if n == 0 && err == nil { + // Clients expect an update with no progress and no error to mean "starting download". + // This is not the case here, + // so we don't want to send an update in this case. + return + } completed.Add(n) t.update(l, received.Add(n), err) } @@ -562,6 +581,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error { } }() + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + + timer := time.AfterFunc(r.readTimeout(), func() { + cancel(fmt.Errorf("%w: downloading %s %d-%d/%d", + context.DeadlineExceeded, + cs.Digest.Short(), + cs.Chunk.Start, + cs.Chunk.End, + l.Size, + )) + }) + defer timer.Stop() + req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil) if err != nil { return err @@ -574,8 +607,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error { defer res.Body.Close() tr := &trackingReader{ - r: res.Body, - update: update, + r: res.Body, + update: func(n int64, err error) { + timer.Reset(r.readTimeout()) + update(n, err) + }, } if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil { return err @@ -930,12 +966,6 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R // is parsed from the response body and returned. If any other error occurs, it // is returned. 
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("request error %s: %w", r.URL, err) - } - }() - if r.URL.Scheme == "https+insecure" { // TODO(bmizerany): clone client.Transport, set // InsecureSkipVerify, etc. diff --git a/server/internal/client/ollama/registry_synctest_test.go b/server/internal/client/ollama/registry_synctest_test.go new file mode 100644 index 000000000..2b4543375 --- /dev/null +++ b/server/internal/client/ollama/registry_synctest_test.go @@ -0,0 +1,51 @@ +// TODO: go:build goexperiment.synctest + +package ollama + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" +) + +func TestPullDownloadTimeout(t *testing.T) { + rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) { + defer t.Log("upstream", r.Method, r.URL.Path) + switch { + case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"): + io.WriteString(w, `{ + "layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}] + }`) + case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"): + // Get headers out to client and then hang on the response + w.WriteHeader(200) + w.(http.Flusher).Flush() + + // Hang on the response and unblock when the client + // gives up + <-r.Context().Done() + default: + t.Fatalf("unexpected request: %s", r.URL.Path) + } + }) + rc.ReadTimeout = 100 * time.Millisecond + + done := make(chan error, 1) + go func() { + done <- rc.Pull(ctx, "http://example.com/library/smol") + }() + + select { + case err := <-done: + want := context.DeadlineExceeded + if !errors.Is(err, want) { + t.Errorf("err = %v, want %v", err, want) + } + case <-time.After(3 * time.Second): + t.Error("timeout waiting for Pull to finish") + } +} diff --git a/server/internal/client/ollama/registry_test.go 
b/server/internal/client/ollama/registry_test.go index 8a3107356..474756725 100644 --- a/server/internal/client/ollama/registry_test.go +++ b/server/internal/client/ollama/registry_test.go @@ -605,8 +605,8 @@ func checkRequest(t *testing.T, req *http.Request, method, path string) { } } -func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) { - s := httptest.NewServer(h) +func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) { + s := httptest.NewServer(upstream) t.Cleanup(s.Close) cache, err := blob.Open(t.TempDir()) if err != nil { diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index 620948985..bd5f7dcd5 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -4,6 +4,7 @@ package registry import ( "cmp" + "context" "encoding/json" "errors" "fmt" @@ -11,6 +12,7 @@ import ( "log/slog" "net/http" "slices" + "strings" "sync" "time" @@ -330,9 +332,8 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { return err } err := s.Client.Pull(ctx, p.model()) - var oe *ollama.Error - if errors.As(err, &oe) && oe.Temporary() { - continue // retry + if canRetry(err) { + continue } return err } @@ -390,3 +391,20 @@ func decodeUserJSON[T any](r io.Reader) (T, error) { } return zero, err } + +func canRetry(err error) bool { + if err == nil { + return false + } + var oe *ollama.Error + if errors.As(err, &oe) { + return oe.Temporary() + } + s := err.Error() + return cmp.Or( + errors.Is(err, context.DeadlineExceeded), + strings.Contains(s, "unreachable"), + strings.Contains(s, "no route to host"), + strings.Contains(s, "connection reset by peer"), + ) +}