This commit is contained in:
Josh Yan 2024-07-03 16:31:53 -07:00
parent a6cfe7f00b
commit 64607d16a5
3 changed files with 105 additions and 17 deletions

View File

@ -358,7 +358,11 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
func (c *Client) CreateBlob(ctx context.Context, digest string, local bool, r io.Reader) error {
headers := make(http.Header)
if local {
headers.Set("X-Redirect-Create", "1")
}
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
}

View File

@ -7,6 +7,7 @@ import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
@ -15,6 +16,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@ -289,8 +291,9 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
// Resolve server to IP
// Check if server is local
if client.IsLocal() {
config, err := client.ServerConfig(cmd.Context())
/* if client.IsLocal() {
digest = strings.ReplaceAll(digest, ":", "-")
config, err := client.HeadBlob(cmd.Context(), digest)
if err != nil {
return "", err
}
@ -298,42 +301,106 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
modelDir := config.ModelDir
// Get blob destination
digest = strings.ReplaceAll(digest, ":", "-")
dest := filepath.Join(modelDir, "blobs", digest)
err = createBlobLocal(path, dest)
if err == nil {
return digest, nil
}
} */
if client.IsLocal() {
config, err := getLocalPath(cmd.Context(), digest)
if err != nil {
return "", err
}
if config == nil {
fmt.Println("config is nil")
return digest, nil
}
fmt.Println("HI")
dest := config.ModelDir
fmt.Println("dest is ", dest)
err = createBlobLocal(path, dest)
if err == nil {
fmt.Println("createlocalblob succeed")
return digest, nil
}
fmt.Println("err is ", err)
fmt.Println("createlocalblob faileds")
}
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
fmt.Println("DEFAULT")
if err = client.CreateBlob(cmd.Context(), digest, false, bin); err != nil {
return "", err
}
return digest, nil
}
func getLocalPath(ctx context.Context, digest string) (*api.ServerConfig, error) {
ollamaHost := envconfig.Host
client := http.DefaultClient
base := &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
}
var reqBody io.Reader
var respData api.ServerConfig
data, err := json.Marshal(digest)
if err != nil {
return nil, err
}
reqBody = bytes.NewReader(data)
path := fmt.Sprintf("/api/blobs/%s", digest)
requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1")
fmt.Println("request", request)
resp, err := client.Do(request)
if err != nil {
return nil, err
}
defer resp.Body.Close()
fmt.Println("made it here")
fmt.Println("resp", resp)
if resp.StatusCode == http.StatusTemporaryRedirect {
fmt.Println("redirect")
if err := json.Unmarshal([]byte(resp.Header.Get("loc")), &respData); err != nil {
fmt.Println("error unmarshalling response data")
return nil, err
}
}
fmt.Println("!!!!!!!!!!")
fmt.Println(respData)
return &respData, nil
}
func createBlobLocal(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
fmt.Println("dirpath is ", dirPath)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
fmt.Println("failed to create directory")
return err
}
// Check blob exists
_, err := os.Stat(dest)
switch {
case errors.Is(err, os.ErrNotExist):
// noop
case err != nil:
return err
default:
// blob already exists
return nil
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {

View File

@ -782,6 +782,23 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
return
}
fmt.Println("HEIAHOEIHFOAHAEFHAO")
fmt.Println(c.GetHeader("X-Redirect-Create"))
if c.GetHeader("X-Redirect-Create") == "1" {
response := api.ServerConfig{ModelDir: path}
fmt.Println("Hit redirect")
resp, err := json.Marshal(response)
fmt.Println("marshalled response")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Header("loc", string(resp))
fmt.Println("!!!!!!!!!", string(resp))
c.Status(http.StatusTemporaryRedirect)
return
}
layer, err := NewLayer(c.Request.Body, "")
if err != nil {