From 33848ad10faad7cb1041b74520155cec52ad8de0 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Tue, 23 Jul 2024 12:26:05 -0700 Subject: [PATCH] serverside copy --- cmd/cmd.go | 53 ++++++++++--------------- {cmd => server}/copy_darwin.go | 2 +- {cmd => server}/copy_linux.go | 2 +- {cmd => server}/copy_windows.go | 2 +- server/routes.go | 69 +++++++++++++++++++++++++++++++-- 5 files changed, 88 insertions(+), 40 deletions(-) rename {cmd => server}/copy_darwin.go (95%) rename {cmd => server}/copy_linux.go (89%) rename {cmd => server}/copy_windows.go (98%) diff --git a/cmd/cmd.go b/cmd/cmd.go index 06bc13332..ca5d7b7be 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -111,7 +111,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { path = tempfile } - digest, err := createBlob(cmd, client, path) + digest, err := createBlob(cmd, path) if err != nil { return err } @@ -264,7 +264,7 @@ func tempZipFiles(path string) (string, error) { var ErrBlobExists = errors.New("blob exists") -func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { +func createBlob(cmd *cobra.Command, path string) (string, error) { bin, err := os.Open(path) if err != nil { return "", err @@ -286,36 +286,20 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er // If we can, we return the path to the directory // If we can't, we return an error // If the blob exists already, we return the digest - dest, err := getLocalPath(cmd.Context(), digest) + err = CreateBlob(cmd.Context(), path, digest) if errors.Is(err, ErrBlobExists) { return digest, nil } - // Successfully found the model directory - if err == nil { - // Copy blob in via OS specific copy - // Linux errors out to use io.copy - err = localCopy(path, dest) - if err == nil { - return digest, nil - } - - // Default copy using io.copy - err = defaultCopy(path, dest) - if err == nil { - return digest, nil - } - } - - // If at any point copying the blob over locally fails, we default to the copy through the server - if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + if err != nil { return "", err } + return digest, nil } -func getLocalPath(ctx context.Context, digest string) (string, error) { +func CreateBlob(ctx context.Context, src, digest string) (error) { ollamaHost := envconfig.Host client := http.DefaultClient @@ -326,7 +310,7 @@ func getLocalPath(ctx context.Context, digest string) (string, error) { data, err := json.Marshal(digest) if err != nil { - return "", err + return err } reqBody := bytes.NewReader(data) @@ -334,33 +318,36 @@ func getLocalPath(ctx context.Context, digest string) (string, error) { requestURL := base.JoinPath(path) request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody) if err != nil { - return "", err + return err } authz, err := api.Authorization(ctx, request) if err != nil { - return "", err + return err } request.Header.Set("Authorization", authz) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) - request.Header.Set("X-Redirect-Create", "1") + request.Header.Set("X-Ollama-File", src) resp, err := client.Do(request) if err != nil { - return "", err + return err } defer resp.Body.Close() - if resp.StatusCode == http.StatusTemporaryRedirect { - dest := resp.Header.Get("LocalLocation") - - return dest, nil + if resp.StatusCode == http.StatusCreated { + return nil } - return "", ErrBlobExists + + if resp.StatusCode == http.StatusOK { + return ErrBlobExists + } + + return err } -func defaultCopy(path string, dest string) error { +func DefaultCopy(path string, dest string) error { // This function should be called if the server is local // It should find the model directory, copy the blob over, and return the digest dirPath := filepath.Dir(dest) diff --git a/cmd/copy_darwin.go b/server/copy_darwin.go similarity index 95% rename from cmd/copy_darwin.go rename to server/copy_darwin.go index fab84d170..0631970dd 100644 --- a/cmd/copy_darwin.go +++ b/server/copy_darwin.go @@ -1,4 +1,4 @@ -package cmd +package server import ( "os" diff --git a/cmd/copy_linux.go b/server/copy_linux.go similarity index 89% rename from cmd/copy_linux.go rename to server/copy_linux.go index 29978da08..4b9407f45 100644 --- a/cmd/copy_linux.go +++ b/server/copy_linux.go @@ -1,4 +1,4 @@ -package cmd +package server import "errors" diff --git a/cmd/copy_windows.go b/server/copy_windows.go similarity index 98% rename from cmd/copy_windows.go rename to server/copy_windows.go index 76a985a1d..4afd44526 100644 --- a/cmd/copy_windows.go +++ b/server/copy_windows.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -package cmd +package server import ( "os" diff --git a/server/routes.go b/server/routes.go index e601710f4..6ef1cb202 100644 --- a/server/routes.go +++ b/server/routes.go @@ -942,10 +942,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusOK) return } - if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) { - c.Header("LocalLocation", path) - c.Status(http.StatusTemporaryRedirect) - return + + if c.GetHeader("X-Ollama-File") != "" && s.isLocal(c) { + err = localBlobCopy(c.GetHeader("X-Ollama-File"), path) + if err == nil { + c.Status(http.StatusCreated) + return + } } layer, err := NewLayer(c.Request.Body, "") @@ -962,6 +965,29 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusCreated) } +func localBlobCopy (src, dest string) error { + _, err := os.Stat(src) + switch { + case errors.Is(err, os.ErrNotExist): + return err + case err != nil: + return err + default: + } + + err = localCopy(src, dest) + if err == nil { + return nil + } + + err = defaultCopy(src, dest) + if err == nil { + return nil + } + + return err +} + func (s *Server) isLocal(c *gin.Context) bool { if authz := c.GetHeader("Authorization"); authz != "" { parts := strings.Split(authz, ":") @@ -1010,6 +1036,41 @@ func (s *Server) isLocal(c *gin.Context) bool { return false } +func defaultCopy(path string, dest string) error { + // This function should be called if the server is local + // It should find the model directory, copy the blob over, and return the digest + dirPath := filepath.Dir(dest) + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return err + } + + // Copy blob over + sourceFile, err := os.Open(path) + if err != nil { + return fmt.Errorf("could not open source file: %v", err) + } + defer sourceFile.Close() + + destFile, err := os.Create(dest) + if err != nil { + return fmt.Errorf("could not create destination file: %v", err) + } + defer destFile.Close() + + _, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024)) + if err != nil { + return fmt.Errorf("error copying file: %v", err) + } + + err = destFile.Sync() + if err != nil { + return fmt.Errorf("error flushing file: %v", err) + } + + return nil +} + func isLocalIP(ip netip.Addr) bool { if interfaces, err := net.Interfaces(); err == nil { for _, iface := range interfaces {