diff --git a/api/client.go b/api/client.go index 3c47ac39b..2b8bde1be 100644 --- a/api/client.go +++ b/api/client.go @@ -367,7 +367,11 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd // CreateBlob creates a blob from a file on the server. digest is the // expected SHA256 digest of the file, and r represents the file. -func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error { +func (c *Client) CreateBlob(ctx context.Context, digest string, local bool, r io.Reader) error { + headers := make(http.Header) + if local { + headers.Set("X-Redirect-Create", "1") + } return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil) } diff --git a/cmd/cmd.go b/cmd/cmd.go index 542283e75..932182157 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/sha256" + "encoding/json" "errors" "fmt" "io" @@ -12,6 +13,7 @@ import ( "math" "net" "net/http" + "net/url" "os" "os/signal" "path/filepath" @@ -286,8 +288,9 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er // Resolve server to IP // Check if server is local - if client.IsLocal() { - config, err := client.ServerConfig(cmd.Context()) + /* if client.IsLocal() { + digest = strings.ReplaceAll(digest, ":", "-") + config, err := client.HeadBlob(cmd.Context(), digest) if err != nil { return "", err } @@ -295,42 +298,106 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er modelDir := config.ModelDir // Get blob destination - digest = strings.ReplaceAll(digest, ":", "-") + dest := filepath.Join(modelDir, "blobs", digest) err = createBlobLocal(path, dest) if err == nil { return digest, nil } + } */ + if client.IsLocal() { + config, err := getLocalPath(cmd.Context(), digest) + if err != nil { + return "", err + } + + if config == nil { + fmt.Println("config is nil") + return digest, nil + } + + fmt.Println("HI") + dest := config.ModelDir + fmt.Println("dest is ", dest) + err = createBlobLocal(path, dest) + if err == nil { + fmt.Println("createlocalblob succeed") + return digest, nil + } + fmt.Println("err is ", err) + fmt.Println("createlocalblob faileds") } - if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + fmt.Println("DEFAULT") + if err = client.CreateBlob(cmd.Context(), digest, false, bin); err != nil { return "", err } return digest, nil } +func getLocalPath(ctx context.Context, digest string) (*api.ServerConfig, error) { + ollamaHost := envconfig.Host + + client := http.DefaultClient + base := &url.URL{ + Scheme: ollamaHost.Scheme, + Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port), + } + + var reqBody io.Reader + var respData api.ServerConfig + data, err := json.Marshal(digest) + if err != nil { + return nil, err + } + + reqBody = bytes.NewReader(data) + path := fmt.Sprintf("/api/blobs/%s", digest) + requestURL := base.JoinPath(path) + request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody) + if err != nil { + return nil, err + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) + request.Header.Set("X-Redirect-Create", "1") + + fmt.Println("request", request) + resp, err := client.Do(request) + if err != nil { + return nil, err + } + defer resp.Body.Close() + fmt.Println("made it here") + fmt.Println("resp", resp) + + if resp.StatusCode == http.StatusTemporaryRedirect { + fmt.Println("redirect") + if err := json.Unmarshal([]byte(resp.Header.Get("loc")), &respData); err != nil { + fmt.Println("error unmarshalling response data") + return nil, err + } + } + + fmt.Println("!!!!!!!!!!") + fmt.Println(respData) + return &respData, nil +} + func createBlobLocal(path string, dest string) error { // This function should be called if the server is local // It should find the model directory, copy the blob over, and return the digest dirPath := filepath.Dir(dest) + fmt.Println("dirpath is ", dirPath) if err := os.MkdirAll(dirPath, 0o755); err != nil { + fmt.Println("failed to create directory") return err } - // Check blob exists - _, err := os.Stat(dest) - switch { - case errors.Is(err, os.ErrNotExist): - // noop - case err != nil: - return err - default: - // blob already exists - return nil - } - // Copy blob over sourceFile, err := os.Open(path) if err != nil { diff --git a/server/routes.go b/server/routes.go index faf6ad6f1..b9e7c1c85 100644 --- a/server/routes.go +++ b/server/routes.go @@ -940,6 +940,23 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusOK) return } + fmt.Println("HEIAHOEIHFOAHAEFHAO") + fmt.Println(c.GetHeader("X-Redirect-Create")) + if c.GetHeader("X-Redirect-Create") == "1" { + response := api.ServerConfig{ModelDir: path} + fmt.Println("Hit redirect") + resp, err := json.Marshal(response) + fmt.Println("marshalled response") + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.Header("loc", string(resp)) + fmt.Println("!!!!!!!!!", string(resp)) + c.Status(http.StatusTemporaryRedirect) + return + } layer, err := NewLayer(c.Request.Body, "") if err != nil {