From aaa1c08a5dab70fd04628b1ab897b1ae2767d1b8 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Mon, 29 Jul 2024 14:30:03 -0700 Subject: [PATCH] testing and FROM local version copy --- server/images.go | 6 ++- server/model.go | 19 ++++---- server/routes_create_test.go | 85 +++++++++++++++++++++++++++++++++++- 3 files changed, 97 insertions(+), 13 deletions(-) diff --git a/server/images.go b/server/images.go index 8f043769b..295a60768 100644 --- a/server/images.go +++ b/server/images.go @@ -385,7 +385,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio case "model", "adapter": var baseLayers []*layerGGML if name := model.ParseName(c.Args); name.IsValid() { - baseLayers, err = parseFromModel(ctx, name, fn) + baseLayers, version, err = parseFromModel(ctx, name, fn) if err != nil { return err } @@ -531,7 +531,9 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio messages = append(messages, &api.Message{Role: role, Content: content}) case "ollama": - version = c.Args + if version == "" { + version = c.Args + } default: ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) if err != nil { diff --git a/server/model.go b/server/model.go index c6d3078f1..39922700d 100644 --- a/server/model.go +++ b/server/model.go @@ -30,26 +30,27 @@ type layerGGML struct { *llm.GGML } -func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { +func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, version string, err error) { m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { - return nil, err + return nil, version, err } m, err = ParseNamedManifest(name) if err != nil { - return nil, err + return nil, version, err } case err != nil: - return nil, err + return nil, version, err } + version = m.Ollama for _, layer := range m.Layers { layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) if err != nil { - return nil, err + return nil, version, err } switch layer.MediaType { @@ -58,18 +59,18 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe "application/vnd.ollama.image.adapter": blobpath, err := GetBlobsPath(layer.Digest) if err != nil { - return nil, err + return nil, version, err } blob, err := os.Open(blobpath) if err != nil { - return nil, err + return nil, version, err } defer blob.Close() ggml, _, err := llm.DecodeGGML(blob, 0) if err != nil { - return nil, err + return nil, version, err } layers = append(layers, &layerGGML{layer, ggml}) @@ -78,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe } } - return layers, nil + return layers, version, nil } func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error { diff --git a/server/routes_create_test.go b/server/routes_create_test.go index c967f7c31..284798e5f 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -632,7 +632,7 @@ func TestCreateVersion(t *testing.T){ envconfig.LoadConfig() var s Server - w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ Name: "test", Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0.2.3\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)), Stream: &stream, @@ -642,9 +642,59 @@ func TestCreateVersion(t *testing.T){ t.Fatalf("expected status code 200, actual %d", w.Code) } + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest")) + if err != nil { + t.Fatal(err) + } + bts := json.NewDecoder(f) + + var m Manifest + if err := bts.Decode(&m); err != nil { + t.Fatal(err) + } + + if m.Ollama != "0.2.3" { + t.Errorf("got %s != want 0.2.3", m.Ollama) + } + + t.Run("no version", func(t *testing.T) { + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "noversion", + Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "noversion", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "noversion", "latest"), + }) + + f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "noversion", "latest")) + if err != nil { + t.Fatal(err) + } + + bts := json.NewDecoder(f) + var m Manifest + if err := bts.Decode(&m); err != nil { + t.Fatal(err) + } + + if m.Ollama != "" { + t.Errorf("got %s != want \"\"", m.Ollama) + } + }) + t.Run("invalid version", func(t *testing.T) { w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Name: "invalid", Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0..400", createBinFile(t, nil, nil)), Stream: &stream, }) @@ -653,4 +703,35 @@ func TestCreateVersion(t *testing.T){ t.Fatalf("expected status code 400, actual %d", w.Code) } }) + + t.Run("from valid version", func(t *testing.T) { + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "fromvalid", + Modelfile: "FROM test", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "fromvalid", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "fromvalid", "latest"), + }) + + f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "fromvalid", "latest")) + if err != nil { + t.Fatal(err) + } + bts := json.NewDecoder(f) + + var m Manifest + if err := bts.Decode(&m); err != nil { + t.Fatal(err) + } + + if m.Ollama != "0.2.3" { + t.Errorf("got %s != want 0.2.3", m.Ollama) + } + }) } \ No newline at end of file