server/internal/client/ollama: handle some network errors gracefully (#10317)
This commit is contained in:
parent
09bb2e30f6
commit
1d99451ad7
@ -223,8 +223,21 @@ type Registry struct {
|
|||||||
ChunkingThreshold int64
|
ChunkingThreshold int64
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified names
|
// Mask, if set, is the name used to convert non-fully qualified names
|
||||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
// to fully qualified names.
|
||||||
|
// If empty, [DefaultMask] is used.
|
||||||
Mask string
|
Mask string
|
||||||
|
|
||||||
|
// ReadTimeout is the maximum duration for reading the entire request,
|
||||||
|
// including the body.
|
||||||
|
// A zero or negative value means there will be no timeout.
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Registry) readTimeout() time.Duration {
|
||||||
|
if r.ReadTimeout > 0 {
|
||||||
|
return r.ReadTimeout
|
||||||
|
}
|
||||||
|
return 1<<63 - 1 // no timeout, max int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Registry) cache() (*blob.DiskCache, error) {
|
func (r *Registry) cache() (*blob.DiskCache, error) {
|
||||||
@ -248,8 +261,7 @@ func (r *Registry) parseName(name string) (names.Name, error) {
|
|||||||
|
|
||||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||||
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
|
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
|
||||||
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
|
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ReadTimeout is set to 30 seconds.
|
||||||
// system's temporary directory.
|
|
||||||
//
|
//
|
||||||
// It returns an error if any configuration in the environment is invalid.
|
// It returns an error if any configuration in the environment is invalid.
|
||||||
func DefaultRegistry() (*Registry, error) {
|
func DefaultRegistry() (*Registry, error) {
|
||||||
@ -263,6 +275,7 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rc Registry
|
var rc Registry
|
||||||
|
rc.ReadTimeout = 30 * time.Second
|
||||||
rc.UserAgent = UserAgent()
|
rc.UserAgent = UserAgent()
|
||||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -489,6 +502,12 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
for _, l := range layers {
|
for _, l := range layers {
|
||||||
var received atomic.Int64
|
var received atomic.Int64
|
||||||
update := func(n int64, err error) {
|
update := func(n int64, err error) {
|
||||||
|
if n == 0 && err == nil {
|
||||||
|
// Clients expect an update with no progress and no error to mean "starting download".
|
||||||
|
// This is not the case here,
|
||||||
|
// so we don't want to send an update in this case.
|
||||||
|
return
|
||||||
|
}
|
||||||
completed.Add(n)
|
completed.Add(n)
|
||||||
t.update(l, received.Add(n), err)
|
t.update(l, received.Add(n), err)
|
||||||
}
|
}
|
||||||
@ -562,6 +581,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancelCause(ctx)
|
||||||
|
defer cancel(nil)
|
||||||
|
|
||||||
|
timer := time.AfterFunc(r.readTimeout(), func() {
|
||||||
|
cancel(fmt.Errorf("%w: downloading %s %d-%d/%d",
|
||||||
|
context.DeadlineExceeded,
|
||||||
|
cs.Digest.Short(),
|
||||||
|
cs.Chunk.Start,
|
||||||
|
cs.Chunk.End,
|
||||||
|
l.Size,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -574,8 +607,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
tr := &trackingReader{
|
tr := &trackingReader{
|
||||||
r: res.Body,
|
r: res.Body,
|
||||||
update: update,
|
update: func(n int64, err error) {
|
||||||
|
timer.Reset(r.readTimeout())
|
||||||
|
update(n, err)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
|
if err := chunked.Put(cs.Chunk, cs.Digest, tr); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -930,12 +966,6 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
|
|||||||
// is parsed from the response body and returned. If any other error occurs, it
|
// is parsed from the response body and returned. If any other error occurs, it
|
||||||
// is returned.
|
// is returned.
|
||||||
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
|
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("request error %s: %w", r.URL, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if r.URL.Scheme == "https+insecure" {
|
if r.URL.Scheme == "https+insecure" {
|
||||||
// TODO(bmizerany): clone client.Transport, set
|
// TODO(bmizerany): clone client.Transport, set
|
||||||
// InsecureSkipVerify, etc.
|
// InsecureSkipVerify, etc.
|
||||||
|
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
51
server/internal/client/ollama/registry_synctest_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// TODO: go:build goexperiment.synctest
|
||||||
|
|
||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPullDownloadTimeout(t *testing.T) {
|
||||||
|
rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer t.Log("upstream", r.Method, r.URL.Path)
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"):
|
||||||
|
io.WriteString(w, `{
|
||||||
|
"layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}]
|
||||||
|
}`)
|
||||||
|
case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"):
|
||||||
|
// Get headers out to client and then hang on the response
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.(http.Flusher).Flush()
|
||||||
|
|
||||||
|
// Hang on the response and unblock when the client
|
||||||
|
// gives up
|
||||||
|
<-r.Context().Done()
|
||||||
|
default:
|
||||||
|
t.Fatalf("unexpected request: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
rc.ReadTimeout = 100 * time.Millisecond
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- rc.Pull(ctx, "http://example.com/library/smol")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
want := context.DeadlineExceeded
|
||||||
|
if !errors.Is(err, want) {
|
||||||
|
t.Errorf("err = %v, want %v", err, want)
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Error("timeout waiting for Pull to finish")
|
||||||
|
}
|
||||||
|
}
|
@ -605,8 +605,8 @@ func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
|
||||||
s := httptest.NewServer(h)
|
s := httptest.NewServer(upstream)
|
||||||
t.Cleanup(s.Close)
|
t.Cleanup(s.Close)
|
||||||
cache, err := blob.Open(t.TempDir())
|
cache, err := blob.Open(t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,6 +4,7 @@ package registry
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -11,6 +12,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -330,9 +332,8 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err := s.Client.Pull(ctx, p.model())
|
err := s.Client.Pull(ctx, p.model())
|
||||||
var oe *ollama.Error
|
if canRetry(err) {
|
||||||
if errors.As(err, &oe) && oe.Temporary() {
|
continue
|
||||||
continue // retry
|
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -390,3 +391,20 @@ func decodeUserJSON[T any](r io.Reader) (T, error) {
|
|||||||
}
|
}
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func canRetry(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var oe *ollama.Error
|
||||||
|
if errors.As(err, &oe) {
|
||||||
|
return oe.Temporary()
|
||||||
|
}
|
||||||
|
s := err.Error()
|
||||||
|
return cmp.Or(
|
||||||
|
errors.Is(err, context.DeadlineExceeded),
|
||||||
|
strings.Contains(s, "unreachable"),
|
||||||
|
strings.Contains(s, "no route to host"),
|
||||||
|
strings.Contains(s, "connection reset by peer"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user