Compare commits
1 Commits
main
...
bmizerany/
Author | SHA1 | Date | |
---|---|---|---|
![]() |
b48b6f85cd |
@ -27,6 +27,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -73,19 +74,22 @@ const (
|
||||
DefaultMaxChunkSize = 8 << 20
|
||||
)
|
||||
|
||||
// DefaultCache returns a new disk cache for storing models. If the
|
||||
// OLLAMA_MODELS environment variable is set, it uses that directory;
|
||||
// otherwise, it uses $HOME/.ollama/models.
|
||||
func DefaultCache() (*blob.DiskCache, error) {
|
||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||
dir := os.Getenv("OLLAMA_MODELS")
|
||||
if dir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
home = cmp.Or(home, ".")
|
||||
dir = filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
return blob.Open(dir)
|
||||
})
|
||||
|
||||
// DefaultCache returns the default cache used by the registry. It is
|
||||
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
||||
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
||||
// it uses the current working directory.
|
||||
func DefaultCache() (*blob.DiskCache, error) {
|
||||
return defaultCache()
|
||||
}
|
||||
|
||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||
@ -168,6 +172,10 @@ func CompleteName(name string) string {
|
||||
// Registry is a client for performing push and pull operations against an
|
||||
// Ollama registry.
|
||||
type Registry struct {
|
||||
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
||||
// used.
|
||||
Cache *blob.DiskCache
|
||||
|
||||
// UserAgent is the User-Agent header to send with requests to the
|
||||
// registry. If empty, the User-Agent is determined by HTTPClient.
|
||||
UserAgent string
|
||||
@ -206,12 +214,18 @@ type Registry struct {
|
||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||
MaxChunkSize int64
|
||||
|
||||
// Mask, if set, is the name used to convert non-fully qualified
|
||||
// names to fully qualified names. If empty, the default mask
|
||||
// ("registry.ollama.ai/library/_:latest") is used.
|
||||
// Mask, if set, is the name used to convert non-fully qualified names
|
||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||
Mask string
|
||||
}
|
||||
|
||||
func (r *Registry) cache() (*blob.DiskCache, error) {
|
||||
if r.Cache != nil {
|
||||
return r.Cache, nil
|
||||
}
|
||||
return defaultCache()
|
||||
}
|
||||
|
||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
||||
mask := defaultMask
|
||||
if r.Mask != "" {
|
||||
@ -241,6 +255,10 @@ func DefaultRegistry() (*Registry, error) {
|
||||
}
|
||||
|
||||
var rc Registry
|
||||
rc.Cache, err = defaultCache()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -282,12 +300,17 @@ type PushParams struct {
|
||||
}
|
||||
|
||||
// Push pushes the model with the name in the cache to the remote registry.
|
||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
||||
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||
if p == nil {
|
||||
p = &PushParams{}
|
||||
}
|
||||
|
||||
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m, err := r.ResolveLocal(cmp.Or(p.From, name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -403,7 +426,7 @@ func canRetry(err error) bool {
|
||||
// chunks of the specified size, and then reassembled and verified. This is
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -417,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||
}
|
||||
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
exists := func(l *Layer) bool {
|
||||
info, err := c.Get(l.Digest)
|
||||
return err == nil && info.Size == l.Size
|
||||
@ -554,11 +582,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
|
||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||
// before attempting to unlink the model.
|
||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return c.Unlink(n.String())
|
||||
}
|
||||
|
||||
@ -631,12 +663,17 @@ type Layer struct {
|
||||
}
|
||||
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
||||
_, n, d, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := r.cache()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !d.IsValid() {
|
||||
// No digest, so resolve the manifest by name.
|
||||
d, err = c.Resolve(n.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
t.Helper()
|
||||
|
||||
c, err := blob.Open(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
}
|
||||
|
||||
r := &Registry{
|
||||
Cache: c,
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(h),
|
||||
},
|
||||
@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||
}
|
||||
|
||||
func TestPushZero(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), c, "empty", nil)
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "empty", nil)
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSingle(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), c, "single", nil)
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "single", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushMultiple(t *testing.T) {
|
||||
rc, c := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), c, "multiple", nil)
|
||||
rc, _ := newClient(t, okHandler)
|
||||
err := rc.Push(t.Context(), "multiple", nil)
|
||||
testutil.Check(t, err)
|
||||
}
|
||||
|
||||
func TestPushNotFound(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("unexpected request: %v", r)
|
||||
})
|
||||
err := rc.Push(t.Context(), c, "notfound", nil)
|
||||
err := rc.Push(t.Context(), "notfound", nil)
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushNullLayer(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), c, "null", nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "null", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushSizeMismatch(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
ctx, _ := withTraceUnexpected(t.Context())
|
||||
got := rc.Push(ctx, c, "sizemismatch", nil)
|
||||
got := rc.Push(ctx, "sizemismatch", nil)
|
||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||
t.Errorf("err = %v; want size mismatch", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushInvalid(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), c, "invalid", nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Push(t.Context(), "invalid", nil)
|
||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
|
||||
|
||||
func TestPushExistsAtRemote(t *testing.T) {
|
||||
var pushed bool
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||
if !pushed {
|
||||
// First push. Return an uploadURL.
|
||||
@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
|
||||
|
||||
check := testutil.Checker(t)
|
||||
|
||||
err := rc.Push(ctx, c, "single", nil)
|
||||
err := rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
|
||||
if !errors.Is(errors.Join(errs...), nil) {
|
||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||
}
|
||||
|
||||
err = rc.Push(ctx, c, "single", nil)
|
||||
err = rc.Push(ctx, "single", nil)
|
||||
check(err)
|
||||
}
|
||||
|
||||
func TestPushRemoteError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(500)
|
||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||
return
|
||||
}
|
||||
})
|
||||
got := rc.Push(t.Context(), c, "single", nil)
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
checkErrCode(t, got, 500, "blob_error")
|
||||
}
|
||||
|
||||
func TestPushLocationError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Location", ":///x")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
got := rc.Push(t.Context(), c, "single", nil)
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
wantContains := "invalid upload URL"
|
||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||
@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPushUploadRoundtripError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "blob.store" {
|
||||
w.WriteHeader(499) // force RoundTrip error on upload
|
||||
return
|
||||
}
|
||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||
})
|
||||
got := rc.Push(t.Context(), c, "single", nil)
|
||||
got := rc.Push(t.Context(), "single", nil)
|
||||
if !errors.Is(got, errRoundTrip) {
|
||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||
}
|
||||
@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
|
||||
os.Remove(c.GetFile(l.Digest))
|
||||
},
|
||||
})
|
||||
got := rc.Push(ctx, c, "single", nil)
|
||||
got := rc.Push(ctx, "single", nil)
|
||||
if !errors.Is(got, fs.ErrNotExist) {
|
||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushCommitRoundtripError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
panic("unexpected")
|
||||
}
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
})
|
||||
err := rc.Push(t.Context(), c, "zero", nil)
|
||||
err := rc.Push(t.Context(), "zero", nil)
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
|
||||
}
|
||||
|
||||
func TestRegistryPullInvalidName(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
err := rc.Pull(t.Context(), c, "://")
|
||||
rc, _ := newClient(t, nil)
|
||||
err := rc.Pull(t.Context(), "://")
|
||||
if !errors.Is(err, ErrNameInvalid) {
|
||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||
}
|
||||
@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, resp := range cases {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, resp)
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "x")
|
||||
err := rc.Pull(t.Context(), "x")
|
||||
if !errors.Is(err, ErrManifestInvalid) {
|
||||
t.Errorf("err = %v; want invalid manifest", err)
|
||||
}
|
||||
@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
|
||||
})
|
||||
|
||||
// Confirm that the layer does not exist locally
|
||||
_, err := rc.ResolveLocal(c, "model")
|
||||
_, err := rc.ResolveLocal("model")
|
||||
checkNotExist(t, err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
checkNotExist(t, err)
|
||||
|
||||
err = rc.Pull(t.Context(), c, "model")
|
||||
err = rc.Pull(t.Context(), "model")
|
||||
check(err)
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "model")
|
||||
check(err)
|
||||
mg, err := rc.ResolveLocal(c, "model")
|
||||
mg, err := rc.ResolveLocal("model")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
||||
|
||||
func TestRegistryPullCached(t *testing.T) {
|
||||
cached := blob.DigestFromBytes("exists")
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||
w.WriteHeader(499) // should not be called
|
||||
return
|
||||
@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := rc.Pull(ctx, c, "single")
|
||||
err := rc.Pull(ctx, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{6}
|
||||
@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "notfound")
|
||||
err := rc.Pull(t.Context(), "notfound")
|
||||
checkErrCode(t, err, 404, "")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "single")
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
checkErrCode(t, err, 500, "an_error")
|
||||
}
|
||||
|
||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||
w.WriteHeader(499) // force RoundTrip error
|
||||
return
|
||||
}
|
||||
})
|
||||
err := rc.Pull(t.Context(), c, "single")
|
||||
err := rc.Pull(t.Context(), "single")
|
||||
if !errors.Is(err, errRoundTrip) {
|
||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||
}
|
||||
@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
|
||||
// Check that we pull all layers that we can.
|
||||
|
||||
err := rc.Pull(ctx, c, "mixed")
|
||||
err := rc.Pull(ctx, "mixed")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistryPullChunking(t *testing.T) {
|
||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||
if r.URL.Host != "blob.store" {
|
||||
// The production registry redirects to the blob store.
|
||||
@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
err := rc.Pull(ctx, c, "remote")
|
||||
err := rc.Pull(ctx, "remote")
|
||||
testutil.Check(t, err)
|
||||
|
||||
want := []int64{0, 3, 6}
|
||||
@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) {
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
rc, _ := newClient(t, nil)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal(c, "single")
|
||||
_, err := rc.ResolveLocal("single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink(c, "single")
|
||||
_, err = rc.Unlink("single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal(c, "single")
|
||||
_, err = rc.ResolveLocal("single")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
ok, err := rc.Unlink(c, "manifestNotFound")
|
||||
rc, _ := newClient(t, nil)
|
||||
ok, err := rc.Unlink("manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -6,6 +6,9 @@ import (
|
||||
|
||||
// Trace is a set of functions that are called to report progress during blob
|
||||
// downloads and uploads.
|
||||
//
|
||||
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
||||
// and [Registry.Pull].
|
||||
type Trace struct {
|
||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||
// report the progress of blob uploads and downloads.
|
||||
|
@ -63,25 +63,28 @@ func main() {
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = func() error {
|
||||
err := func() error {
|
||||
switch cmd := flag.Arg(0); cmd {
|
||||
case "pull":
|
||||
return cmdPull(ctx, rc, c)
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return cmdPull(ctx, rc)
|
||||
case "push":
|
||||
return cmdPush(ctx, rc, c)
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cmdPush(ctx, rc)
|
||||
case "import":
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cmdImport(ctx, c)
|
||||
default:
|
||||
if cmd == "" {
|
||||
@ -99,7 +102,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
||||
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
||||
model := flag.Arg(1)
|
||||
if model == "" {
|
||||
flag.Usage()
|
||||
@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
|
||||
errc := make(chan error)
|
||||
go func() {
|
||||
errc <- rc.Pull(ctx, c, model)
|
||||
errc <- rc.Pull(ctx, model)
|
||||
}()
|
||||
|
||||
t := time.NewTicker(time.Second)
|
||||
@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
}
|
||||
}
|
||||
|
||||
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
||||
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
||||
args := flag.Args()[1:]
|
||||
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
||||
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
||||
@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
}
|
||||
|
||||
from := cmp.Or(*flagFrom, model)
|
||||
m, err := rc.ResolveLocal(c, from)
|
||||
m, err := rc.ResolveLocal(from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
},
|
||||
})
|
||||
|
||||
return rc.Push(ctx, c, model, &ollama.PushParams{
|
||||
return rc.Push(ctx, model, &ollama.PushParams{
|
||||
From: from,
|
||||
})
|
||||
}
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
)
|
||||
|
||||
@ -27,7 +26,6 @@ import (
|
||||
// directly to the blob disk cache.
|
||||
type Local struct {
|
||||
Client *ollama.Registry // required
|
||||
Cache *blob.DiskCache // required
|
||||
Logger *slog.Logger // required
|
||||
|
||||
// Fallback, if set, is used to handle requests that are not handled by
|
||||
@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ok, err := s.Client.Unlink(s.Cache, p.model())
|
||||
ok, err := s.Client.Unlink(p.model())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rc := &ollama.Registry{
|
||||
Cache: c,
|
||||
HTTPClient: panicOnRoundTrip,
|
||||
}
|
||||
l := &Local{
|
||||
Cache: c,
|
||||
Client: rc,
|
||||
Logger: testutil.Slogger(t),
|
||||
}
|
||||
@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
|
||||
|
||||
s := newTestServer(t)
|
||||
|
||||
_, err := s.Client.ResolveLocal(s.Cache, "smol")
|
||||
_, err := s.Client.ResolveLocal("smol")
|
||||
check(err)
|
||||
|
||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||
@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
|
||||
_, err = s.Client.ResolveLocal(s.Cache, "smol")
|
||||
_, err = s.Client.ResolveLocal("smol")
|
||||
if err == nil {
|
||||
t.Fatal("expected smol to have been deleted")
|
||||
}
|
||||
|
@ -34,7 +34,6 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/models/mllama"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
|
||||
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
corsConfig := cors.DefaultConfig()
|
||||
corsConfig.AllowWildcard = true
|
||||
corsConfig.AllowBrowserExtensions = true
|
||||
@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
|
||||
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
Cache: c,
|
||||
Client: rc,
|
||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||
Fallback: r,
|
||||
@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h, err := s.GenerateRoutes(c, rc)
|
||||
h, err := s.GenerateRoutes(rc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -23,7 +23,6 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
|
||||
modelsDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
|
||||
c, err := blob.Open(modelsDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open models dir: %v", err)
|
||||
}
|
||||
|
||||
rc := &ollama.Registry{
|
||||
// This is a temporary measure to allow us to move forward,
|
||||
// surfacing any code contacting ollama.com we do not intended
|
||||
@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(c, rc)
|
||||
router, err := s.GenerateRoutes(rc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate routes: %v", err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user