From 6e2be5a8a09efca05698dfe6fd9d5933d1c29295 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 14 Jul 2023 16:28:55 -0700 Subject: [PATCH] add create, pull, and push --- api/client.go | 43 +++++++++++++++++++------ cmd/cmd.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++- server/routes.go | 2 +- 3 files changed, 114 insertions(+), 12 deletions(-) diff --git a/api/client.go b/api/client.go index b58e53a96..f19f6c5e4 100644 --- a/api/client.go +++ b/api/client.go @@ -107,15 +107,38 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate type PullProgressFunc func(PullProgress) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { - /* - return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { - var resp PullProgress - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } + return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { + var resp PullProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } - return fn(resp) - }) - */ - return nil + return fn(resp) + }) +} + +type PushProgressFunc func(PushProgress) error + +func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { + var resp PushProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + +type CreateProgressFunc func(CreateProgress) error + +func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { + var resp CreateProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) } diff --git a/cmd/cmd.go b/cmd/cmd.go index a761dd04e..23b6eda6f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -30,6 +30,23 @@ func cacheDir() string { return filepath.Join(home, ".ollama") } +func create(cmd *cobra.Command, args []string) error { + filename, _ := cmd.Flags().GetString("file") + client := api.NewClient() + + request := api.CreateRequest{Name: args[0], Path: filename} + fn := func(resp api.CreateProgress) error { + fmt.Println(resp.Status) + return nil + } + + if err := client.Create(context.Background(), &request, fn); err != nil { + return err + } + + return nil +} + func RunRun(cmd *cobra.Command, args []string) error { _, err := os.Stat(args[0]) switch { @@ -51,8 +68,37 @@ func RunRun(cmd *cobra.Command, args []string) error { return RunGenerate(cmd, args) } +func push(cmd *cobra.Command, args []string) error { + client := api.NewClient() + + request := api.PushRequest{Name: args[0]} + fn := func(resp api.PushProgress) error { + fmt.Println(resp.Status) + return nil + } + + if err := client.Push(context.Background(), &request, fn); err != nil { + return err + } + return nil +} + +func RunPull(cmd *cobra.Command, args []string) error { + return pull(args[0]) +} + func pull(model string) error { - // TODO add this back + client := api.NewClient() + + request := api.PullRequest{Name: model} + fn := func(resp api.PullProgress) error { + fmt.Println(resp.Status) + return nil + } + + if err := client.Pull(context.Background(), &request, fn); err != nil { + return err + } return nil } @@ -199,6 +245,15 @@ func NewCLI() *cobra.Command { cobra.EnableCommandSorting = false + createCmd := &cobra.Command{ + Use: "create MODEL", + Short: "Create a model from a Modelfile", + Args: cobra.MinimumNArgs(1), + RunE: create, + } + + createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")") + runCmd := &cobra.Command{ Use: "run MODEL [PROMPT]", Short: "Run a model", @@ -215,9 +270,33 @@ func NewCLI() *cobra.Command { RunE: RunServer, } + pullCmd := &cobra.Command{ + Use: "pull MODEL", + Short: "Pull a model from a registry", + Args: cobra.MinimumNArgs(1), + RunE: RunPull, + } + + pushCmd := &cobra.Command{ + Use: "push MODEL", + Short: "Push a model to a registry", + Args: cobra.MinimumNArgs(1), + RunE: push, + } + rootCmd.AddCommand( serveCmd, + createCmd, runCmd, + pullCmd, + pushCmd, + ) + + rootCmd.AddCommand( + serveCmd, + createCmd, + runCmd, + pullCmd, ) return rootCmd diff --git a/server/routes.go b/server/routes.go index 3035c46f4..eed5756a6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -116,7 +116,7 @@ func pull(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - } + }() streamResponse(c, ch) }