local copy

This commit is contained in:
Josh Yan 2024-07-05 15:05:58 -07:00
parent 163ee9a8b0
commit 6c0a8379f6
5 changed files with 112 additions and 14 deletions

View File

@ -316,7 +316,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
}
if err == nil {
err = createBlobLocal(path, dest)
err = localCopy(path, dest)
if err == nil {
return digest, nil
}
err = defaultCopy(path, dest)
if err == nil {
return digest, nil
}
@ -374,7 +379,7 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
return "", ErrBlobExists
}
func createBlobLocal(path string, dest string) error {
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)

23
cmd/copy_darwin.go Normal file
View File

@ -0,0 +1,23 @@
package cmd
import (
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
func localCopy(src, target string) error {
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
err := unix.Clonefile(src, target, 0)
if err != nil {
return err
}
return nil
}

5
cmd/copy_linux.go Normal file
View File

@ -0,0 +1,5 @@
package cmd
func localCopy(src, target string) error {
return defaultCopy(src, target)
}

53
cmd/copy_windows.go Normal file
View File

@ -0,0 +1,53 @@
package cmd
import (
"os"
"path/filepath"
"syscall"
)
func localCopy(src, target string) error {
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
targetFile, err := os.Create(target)
if err != nil {
return err
}
defer targetFile.Close()
sourceHandle := syscall.Handle(sourceFile.Fd())
targetHandle := syscall.Handle(targetFile.Fd())
err = copyFileEx(sourceHandle, targetHandle)
if err != nil {
return err
}
return nil
}
func copyFileEx(srcHandle, dstHandle syscall.Handle) error {
kernel32 := syscall.NewLazyDLL("kernel32.dll")
copyFileEx := kernel32.NewProc("CopyFileExW")
r1, _, err := copyFileEx.Call(
uintptr(srcHandle),
uintptr(dstHandle),
0, 0, 0, 0)
if r1 == 0 {
return err
}
return nil
}

View File

@ -927,12 +927,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
}
fmt.Println("path2", c.Param("digest"))
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
fmt.Println("path1", path)
_, err = os.Stat(path)
switch {
case errors.Is(err, os.ErrNotExist):
@ -944,8 +945,10 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
return
}
fmt.Println("hello")
fmt.Println(s.IsLocal(c))
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
fmt.Println("entered redirect")
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
@ -966,25 +969,32 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
func (s *Server) IsLocal(c *gin.Context) bool {
fmt.Println("entered islocal")
fmt.Println(c.GetHeader("Authorization"), " is authorization")
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
fmt.Println("failed at lenParts")
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
fmt.Println("failed at parseAuthorizedKey")
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
requestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
fmt.Println("failed at decodeString")
return false
}
partialRequestDataParts := strings.Split(string(partialRequestData), ",")
if len(partialRequestDataParts) != 4 {
partialRequestDataParts := strings.Split(string(requestData), ",")
if len(partialRequestDataParts) != 3 {
fmt.Println("failed at lenPartialRequestDataParts")
return false
}
@ -1007,22 +1017,24 @@ func (s *Server) IsLocal(c *gin.Context) bool {
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
fmt.Println("failed at decodeString stdEncoding")
return false
}
if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
fmt.Println("failed at verify")
fmt.Println(err)
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
fmt.Println("failed at getPublicKey")
log.Fatal(err)
}
_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
fmt.Println("true")
return true
}