testing and FROM local version copy

This commit is contained in:
Josh Yan 2024-07-29 14:30:03 -07:00
parent ab9dfbddea
commit aaa1c08a5d
3 changed files with 97 additions and 13 deletions

View File

@ -385,7 +385,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
case "model", "adapter": case "model", "adapter":
var baseLayers []*layerGGML var baseLayers []*layerGGML
if name := model.ParseName(c.Args); name.IsValid() { if name := model.ParseName(c.Args); name.IsValid() {
baseLayers, err = parseFromModel(ctx, name, fn) baseLayers, version, err = parseFromModel(ctx, name, fn)
if err != nil { if err != nil {
return err return err
} }
@ -531,7 +531,9 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
messages = append(messages, &api.Message{Role: role, Content: content}) messages = append(messages, &api.Message{Role: role, Content: content})
case "ollama": case "ollama":
if version == "" {
version = c.Args version = c.Args
}
default: default:
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err != nil { if err != nil {

View File

@ -30,26 +30,27 @@ type layerGGML struct {
*llm.GGML *llm.GGML
} }
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) { func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, version string, err error) {
m, err := ParseNamedManifest(name) m, err := ParseNamedManifest(name)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil { if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err return nil, version, err
} }
m, err = ParseNamedManifest(name) m, err = ParseNamedManifest(name)
if err != nil { if err != nil {
return nil, err return nil, version, err
} }
case err != nil: case err != nil:
return nil, err return nil, version, err
} }
version = m.Ollama
for _, layer := range m.Layers { for _, layer := range m.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil { if err != nil {
return nil, err return nil, version, err
} }
switch layer.MediaType { switch layer.MediaType {
@ -58,18 +59,18 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
"application/vnd.ollama.image.adapter": "application/vnd.ollama.image.adapter":
blobpath, err := GetBlobsPath(layer.Digest) blobpath, err := GetBlobsPath(layer.Digest)
if err != nil { if err != nil {
return nil, err return nil, version, err
} }
blob, err := os.Open(blobpath) blob, err := os.Open(blobpath)
if err != nil { if err != nil {
return nil, err return nil, version, err
} }
defer blob.Close() defer blob.Close()
ggml, _, err := llm.DecodeGGML(blob, 0) ggml, _, err := llm.DecodeGGML(blob, 0)
if err != nil { if err != nil {
return nil, err return nil, version, err
} }
layers = append(layers, &layerGGML{layer, ggml}) layers = append(layers, &layerGGML{layer, ggml})
@ -78,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
} }
} }
return layers, nil return layers, version, nil
} }
func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error { func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {

View File

@ -642,9 +642,59 @@ func TestCreateVersion(t *testing.T){
t.Fatalf("expected status code 200, actual %d", w.Code) t.Fatalf("expected status code 200, actual %d", w.Code)
} }
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"))
if err != nil {
t.Fatal(err)
}
bts := json.NewDecoder(f)
var m Manifest
if err := bts.Decode(&m); err != nil {
t.Fatal(err)
}
if m.Ollama != "0.2.3" {
t.Errorf("got %s != want 0.2.3", m.Ollama)
}
t.Run("no version", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "noversion",
Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "noversion", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "noversion", "latest"),
})
f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "noversion", "latest"))
if err != nil {
t.Fatal(err)
}
bts := json.NewDecoder(f)
var m Manifest
if err := bts.Decode(&m); err != nil {
t.Fatal(err)
}
if m.Ollama != "" {
t.Errorf("got %s != want \"\"", m.Ollama)
}
})
t.Run("invalid version", func(t *testing.T) { t.Run("invalid version", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test", Name: "invalid",
Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0..400", createBinFile(t, nil, nil)), Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0..400", createBinFile(t, nil, nil)),
Stream: &stream, Stream: &stream,
}) })
@ -653,4 +703,35 @@ func TestCreateVersion(t *testing.T){
t.Fatalf("expected status code 400, actual %d", w.Code) t.Fatalf("expected status code 400, actual %d", w.Code)
} }
}) })
t.Run("from valid version", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "fromvalid",
Modelfile: "FROM test",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "fromvalid", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "fromvalid", "latest"),
})
f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "fromvalid", "latest"))
if err != nil {
t.Fatal(err)
}
bts := json.NewDecoder(f)
var m Manifest
if err := bts.Decode(&m); err != nil {
t.Fatal(err)
}
if m.Ollama != "0.2.3" {
t.Errorf("got %s != want 0.2.3", m.Ollama)
}
})
} }