serverside copy

This commit is contained in:
Josh Yan 2024-07-23 12:26:05 -07:00
parent ff06a2916d
commit 33848ad10f
5 changed files with 88 additions and 40 deletions

View File

@ -111,7 +111,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path) digest, err := createBlob(cmd, path)
if err != nil { if err != nil {
return err return err
} }
@ -264,7 +264,7 @@ func tempZipFiles(path string) (string, error) {
var ErrBlobExists = errors.New("blob exists") var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, path string) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
@ -286,36 +286,20 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
// If we can, we return the path to the directory // If we can, we return the path to the directory
// If we can't, we return an error // If we can't, we return an error
// If the blob exists already, we return the digest // If the blob exists already, we return the digest
dest, err := getLocalPath(cmd.Context(), digest) err = CreateBlob(cmd.Context(), path, digest)
if errors.Is(err, ErrBlobExists) { if errors.Is(err, ErrBlobExists) {
return digest, nil return digest, nil
} }
// Successfully found the model directory if err != nil {
if err == nil {
// Copy blob in via OS specific copy
// Linux errors out to use io.copy
err = localCopy(path, dest)
if err == nil {
return digest, nil
}
// Default copy using io.copy
err = defaultCopy(path, dest)
if err == nil {
return digest, nil
}
}
// If at any point copying the blob over locally fails, we default to the copy through the server
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
} }
func getLocalPath(ctx context.Context, digest string) (string, error) { func CreateBlob(ctx context.Context, src, digest string) (error) {
ollamaHost := envconfig.Host ollamaHost := envconfig.Host
client := http.DefaultClient client := http.DefaultClient
@ -326,7 +310,7 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
data, err := json.Marshal(digest) data, err := json.Marshal(digest)
if err != nil { if err != nil {
return "", err return err
} }
reqBody := bytes.NewReader(data) reqBody := bytes.NewReader(data)
@ -334,33 +318,36 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
requestURL := base.JoinPath(path) requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody) request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
if err != nil { if err != nil {
return "", err return err
} }
authz, err := api.Authorization(ctx, request) authz, err := api.Authorization(ctx, request)
if err != nil { if err != nil {
return "", err return err
} }
request.Header.Set("Authorization", authz) request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1") request.Header.Set("X-Ollama-File", src)
resp, err := client.Do(request) resp, err := client.Do(request)
if err != nil { if err != nil {
return "", err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == http.StatusTemporaryRedirect { if resp.StatusCode == http.StatusCreated {
dest := resp.Header.Get("LocalLocation") return nil
return dest, nil
} }
return "", ErrBlobExists
if resp.StatusCode == http.StatusOK {
return ErrBlobExists
}
return err
} }
func defaultCopy(path string, dest string) error { func DefaultCopy(path string, dest string) error {
// This function should be called if the server is local // This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest // It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest) dirPath := filepath.Dir(dest)

View File

@ -1,4 +1,4 @@
package cmd package server
import ( import (
"os" "os"

View File

@ -1,4 +1,4 @@
package cmd package server
import "errors" import "errors"

View File

@ -1,7 +1,7 @@
//go:build windows //go:build windows
// +build windows // +build windows
package cmd package server
import ( import (
"os" "os"

View File

@ -942,10 +942,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
return return
} }
if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
c.Header("LocalLocation", path) if c.GetHeader("X-Ollama-File") != "" && s.isLocal(c) {
c.Status(http.StatusTemporaryRedirect) err = localBlobCopy(c.GetHeader("X-Ollama-File"), path)
return if err == nil {
c.Status(http.StatusCreated)
return
}
} }
layer, err := NewLayer(c.Request.Body, "") layer, err := NewLayer(c.Request.Body, "")
@ -962,6 +965,29 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated) c.Status(http.StatusCreated)
} }
func localBlobCopy (src, dest string) error {
_, err := os.Stat(src)
switch {
case errors.Is(err, os.ErrNotExist):
return err
case err != nil:
return err
default:
}
err = localCopy(src, dest)
if err == nil {
return nil
}
err = defaultCopy(src, dest)
if err == nil {
return nil
}
return err
}
func (s *Server) isLocal(c *gin.Context) bool { func (s *Server) isLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" { if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":") parts := strings.Split(authz, ":")
@ -1010,6 +1036,41 @@ func (s *Server) isLocal(c *gin.Context) bool {
return false return false
} }
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open source file: %v", err)
}
defer sourceFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return fmt.Errorf("could not create destination file: %v", err)
}
defer destFile.Close()
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
if err != nil {
return fmt.Errorf("error copying file: %v", err)
}
err = destFile.Sync()
if err != nil {
return fmt.Errorf("error flushing file: %v", err)
}
return nil
}
func isLocalIP(ip netip.Addr) bool { func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil { if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces { for _, iface := range interfaces {