serverside copy

This commit is contained in:
Josh Yan 2024-07-23 12:26:05 -07:00
parent ff06a2916d
commit 33848ad10f
5 changed files with 88 additions and 40 deletions

View File

@ -111,7 +111,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile
}
digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, path)
if err != nil {
return err
}
@ -264,7 +264,7 @@ func tempZipFiles(path string) (string, error) {
var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
func createBlob(cmd *cobra.Command, path string) (string, error) {
bin, err := os.Open(path)
if err != nil {
return "", err
@ -286,36 +286,20 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
// If we can, we return the path to the directory
// If we can't, we return an error
// If the blob exists already, we return the digest
dest, err := getLocalPath(cmd.Context(), digest)
err = CreateBlob(cmd.Context(), path, digest)
if errors.Is(err, ErrBlobExists) {
return digest, nil
}
// Successfully found the model directory
if err == nil {
// Copy blob in via OS specific copy
// Linux errors out to use io.copy
err = localCopy(path, dest)
if err == nil {
return digest, nil
}
// Default copy using io.copy
err = defaultCopy(path, dest)
if err == nil {
return digest, nil
}
}
// If at any point copying the blob over locally fails, we default to the copy through the server
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
if err != nil {
return "", err
}
return digest, nil
}
func getLocalPath(ctx context.Context, digest string) (string, error) {
func CreateBlob(ctx context.Context, src, digest string) (error) {
ollamaHost := envconfig.Host
client := http.DefaultClient
@ -326,7 +310,7 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
data, err := json.Marshal(digest)
if err != nil {
return "", err
return err
}
reqBody := bytes.NewReader(data)
@ -334,33 +318,36 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
if err != nil {
return "", err
return err
}
authz, err := api.Authorization(ctx, request)
if err != nil {
return "", err
return err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1")
request.Header.Set("X-Ollama-File", src)
resp, err := client.Do(request)
if err != nil {
return "", err
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusTemporaryRedirect {
dest := resp.Header.Get("LocalLocation")
return dest, nil
if resp.StatusCode == http.StatusCreated {
return nil
}
return "", ErrBlobExists
if resp.StatusCode == http.StatusOK {
return ErrBlobExists
}
return err
}
func defaultCopy(path string, dest string) error {
func DefaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)

View File

@ -1,4 +1,4 @@
package cmd
package server
import (
"os"

View File

@ -1,4 +1,4 @@
package cmd
package server
import "errors"

View File

@ -1,7 +1,7 @@
//go:build windows
// +build windows
package cmd
package server
import (
"os"

View File

@ -942,10 +942,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
return
}
if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
if c.GetHeader("X-Ollama-File") != "" && s.isLocal(c) {
err = localBlobCopy(c.GetHeader("X-Ollama-File"), path)
if err == nil {
c.Status(http.StatusCreated)
return
}
}
layer, err := NewLayer(c.Request.Body, "")
@ -962,6 +965,29 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated)
}
func localBlobCopy (src, dest string) error {
_, err := os.Stat(src)
switch {
case errors.Is(err, os.ErrNotExist):
return err
case err != nil:
return err
default:
}
err = localCopy(src, dest)
if err == nil {
return nil
}
err = defaultCopy(src, dest)
if err == nil {
return nil
}
return err
}
func (s *Server) isLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
@ -1010,6 +1036,41 @@ func (s *Server) isLocal(c *gin.Context) bool {
return false
}
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open source file: %v", err)
}
defer sourceFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return fmt.Errorf("could not create destination file: %v", err)
}
defer destFile.Close()
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
if err != nil {
return fmt.Errorf("error copying file: %v", err)
}
err = destFile.Sync()
if err != nil {
return fmt.Errorf("error flushing file: %v", err)
}
return nil
}
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {