diff --git a/go.mod b/go.mod index 2e0c6614c..79f9cf77e 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( ) require ( + github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/chewxy/hm v1.0.0 // indirect diff --git a/go.sum b/go.sum index 926ed26d8..cc6625f30 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7 gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= +github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= diff --git a/parser/parser.go b/parser/parser.go index 7f566da4e..45f79b485 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/Masterminds/semver/v3" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" ) @@ -41,6 +42,8 @@ func (c Command) String() string { case "message": role, message, _ := strings.Cut(c.Args, ": ") fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message)) + case "ollama": + fmt.Fprintf(&sb, "OLLAMA %s", quote(c.Args)) default: fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args)) } @@ -57,12 +60,14 @@ const ( stateParameter stateMessage stateComment + stateVersion ) var ( errMissingFrom = errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") + errInvalidVersion = errors.New("invalid version") ) func ParseFile(r io.Reader) (*File, error) { @@ -109,6 +114,9 @@ func ParseFile(r io.Reader) (*File, error) { case "message": // transition to stateMessage which validates the message role next = stateMessage + cmd.Name = s + case "ollama": + next = stateVersion fallthrough default: cmd.Name = s @@ -123,6 +131,23 @@ func ParseFile(r io.Reader) (*File, error) { role = b.String() case stateComment, stateNil: // pass + case stateVersion: + s, ok := unquote(strings.TrimSpace(b.String())) + if !ok { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + + continue + } else if isSpace(r){ + return nil, errInvalidVersion + } else if _, err := semver.NewVersion(s); err != nil { + return nil, errInvalidVersion + } + + cmd.Args = s + f.Commands = append(f.Commands, cmd) + case stateValue: s, ok := unquote(strings.TrimSpace(b.String())) if !ok || isSpace(r) { @@ -157,6 +182,16 @@ func ParseFile(r io.Reader) (*File, error) { switch curr { case stateComment, stateNil: // pass; nothing to flush + case stateVersion: + s, ok := unquote(strings.TrimSpace(b.String())) + if !ok { + return nil, io.ErrUnexpectedEOF + } else if _, err := semver.NewVersion(s); err != nil { + return nil, errInvalidVersion + } + + cmd.Args = s + f.Commands = append(f.Commands, cmd) case stateValue: s, ok := unquote(strings.TrimSpace(b.String())) if !ok { @@ -236,6 +271,15 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { default: return stateComment, 0, nil } + case stateVersion: + switch { + case isNewline(r), isSpace(r): + return stateNil, 0, nil + case isAlpha(r), isNumber(r), r == '.': + return stateVersion, r, nil + default: + return stateNil, r, nil + } default: return stateNil, 0, errors.New("") } @@ -296,7 +340,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "parameter", "message": + case "from", "license", "template", "system", "adapter", "parameter", "message", "ollama": return true default: return false diff --git a/server/images.go b/server/images.go index 836dbcc2d..8f043769b 100644 --- a/server/images.go +++ b/server/images.go @@ -374,6 +374,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio } var messages []*api.Message + var version string parameters := make(map[string]any) var layers []*Layer @@ -529,6 +530,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio } messages = append(messages, &api.Message{Role: role, Content: content}) + case "ollama": + version = c.Args default: ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) if err != nil { @@ -545,7 +548,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio } } } - } + } var err2 error layers = slices.DeleteFunc(layers, func(layer *Layer) bool { @@ -642,7 +645,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio old, _ := ParseNamedManifest(name) fn(api.ProgressResponse{Status: "writing manifest"}) - if err := WriteManifest(name, layer, layers); err != nil { + if err := WriteManifest(name, layer, layers, version); err != nil { return err } diff --git a/server/manifest.go b/server/manifest.go index 726bb48d8..9104e257b 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -18,6 +18,7 @@ type Manifest struct { MediaType string `json:"mediaType"` Config *Layer `json:"config"` Layers []*Layer `json:"layers"` + Ollama string `json:"ollama"` filepath string fi os.FileInfo @@ -93,7 +94,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { return &m, nil } -func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { +func WriteManifest(name model.Name, config *Layer, layers []*Layer, ollama string) error { manifests, err := GetManifestPath() if err != nil { return err @@ -115,6 +116,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: config, Layers: layers, + Ollama: ollama, } return json.NewEncoder(f).Encode(m) diff --git a/server/routes_create_test.go b/server/routes_create_test.go index e801a74f5..c967f7c31 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -623,3 +623,34 @@ func TestCreateDetectTemplate(t *testing.T) { }) }) } + +func TestCreateVersion(t *testing.T){ + gin.SetMode(gin.TestMode) + + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + envconfig.LoadConfig() + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0.2.3\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + t.Run("invalid version", func(t *testing.T) { + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0..400", createBinFile(t, nil, nil)), + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) +} \ No newline at end of file diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 33a97a73d..0c0b6c7df 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -99,7 +99,7 @@ func TestDeleteDuplicateLayers(t *testing.T) { } // create a manifest with duplicate layers - if err := WriteManifest(n, config, []*Layer{config}); err != nil { + if err := WriteManifest(n, config, []*Layer{config}, ""); err != nil { t.Fatal(err) }