From 6c0a8379f6715cce337f06eecf3c83c7935df23b Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Fri, 5 Jul 2024 15:05:58 -0700 Subject: [PATCH] local copy --- cmd/cmd.go | 9 ++++++-- cmd/copy_darwin.go | 23 ++++++++++++++++++++ cmd/copy_linux.go | 5 +++++ cmd/copy_windows.go | 53 +++++++++++++++++++++++++++++++++++++++++++++ server/routes.go | 36 ++++++++++++++++++++---------- 5 files changed, 112 insertions(+), 14 deletions(-) create mode 100644 cmd/copy_darwin.go create mode 100644 cmd/copy_linux.go create mode 100644 cmd/copy_windows.go diff --git a/cmd/cmd.go b/cmd/cmd.go index f4fa4fca8..710e14304 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -316,7 +316,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er } if err == nil { - err = createBlobLocal(path, dest) + err = localCopy(path, dest) + if err == nil { + return digest, nil + } + + err = defaultCopy(path, dest) if err == nil { return digest, nil } @@ -374,7 +379,7 @@ func getLocalPath(ctx context.Context, digest string) (string, error) { return "", ErrBlobExists } -func createBlobLocal(path string, dest string) error { +func defaultCopy(path string, dest string) error { // This function should be called if the server is local // It should find the model directory, copy the blob over, and return the digest dirPath := filepath.Dir(dest) diff --git a/cmd/copy_darwin.go b/cmd/copy_darwin.go new file mode 100644 index 000000000..fab84d170 --- /dev/null +++ b/cmd/copy_darwin.go @@ -0,0 +1,23 @@ +package cmd + +import ( + "os" + "path/filepath" + + "golang.org/x/sys/unix" +) + +func localCopy(src, target string) error { + dirPath := filepath.Dir(target) + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return err + } + + err := unix.Clonefile(src, target, 0) + if err != nil { + return err + } + + return nil +} diff --git a/cmd/copy_linux.go b/cmd/copy_linux.go new file mode 100644 index 000000000..e203a390b --- /dev/null +++ b/cmd/copy_linux.go @@ -0,0 +1,5 @@ +package cmd + +func localCopy(src, target string) error { + return defaultCopy(src, target) +} diff --git a/cmd/copy_windows.go b/cmd/copy_windows.go new file mode 100644 index 000000000..18cc9b40e --- /dev/null +++ b/cmd/copy_windows.go @@ -0,0 +1,53 @@ +package cmd + +import ( + "os" + "path/filepath" + "syscall" +) + +func localCopy(src, target string) error { + dirPath := filepath.Dir(target) + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return err + } + + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + targetFile, err := os.Create(target) + if err != nil { + return err + } + defer targetFile.Close() + + sourceHandle := syscall.Handle(sourceFile.Fd()) + targetHandle := syscall.Handle(targetFile.Fd()) + + err = copyFileEx(sourceHandle, targetHandle) + if err != nil { + return err + } + + return nil +} + +func copyFileEx(srcHandle, dstHandle syscall.Handle) error { + kernel32 := syscall.NewLazyDLL("kernel32.dll") + copyFileEx := kernel32.NewProc("CopyFileExW") + + r1, _, err := copyFileEx.Call( + uintptr(srcHandle), + uintptr(dstHandle), + 0, 0, 0, 0) + + if r1 == 0 { + return err + } + + return nil +} diff --git a/server/routes.go b/server/routes.go index 7525f139c..06acf83e2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -927,12 +927,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { } } + fmt.Println("path2", c.Param("digest")) path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - + fmt.Println("path1", path) _, err = os.Stat(path) switch { case errors.Is(err, os.ErrNotExist): @@ -944,8 +945,10 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { c.Status(http.StatusOK) return } - + fmt.Println("hello") + fmt.Println(s.IsLocal(c)) if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) { + fmt.Println("entered redirect") c.Header("LocalLocation", path) c.Status(http.StatusTemporaryRedirect) return @@ -966,25 +969,32 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { } func (s *Server) IsLocal(c *gin.Context) bool { + fmt.Println("entered islocal") + fmt.Println(c.GetHeader("Authorization"), " is authorization") if authz := c.GetHeader("Authorization"); authz != "" { + parts := strings.Split(authz, ":") if len(parts) != 3 { + fmt.Println("failed at lenParts") return false } clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0]))) if err != nil { + fmt.Println("failed at parseAuthorizedKey") return false } // partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce - partialRequestData, err := base64.StdEncoding.DecodeString(parts[1]) + requestData, err := base64.StdEncoding.DecodeString(parts[1]) if err != nil { + fmt.Println("failed at decodeString") return false } - partialRequestDataParts := strings.Split(string(partialRequestData), ",") - if len(partialRequestDataParts) != 4 { + partialRequestDataParts := strings.Split(string(requestData), ",") + if len(partialRequestDataParts) != 3 { + fmt.Println("failed at lenPartialRequestDataParts") return false } @@ -1007,22 +1017,24 @@ func (s *Server) IsLocal(c *gin.Context) bool { signature, err := base64.StdEncoding.DecodeString(parts[2]) if err != nil { + fmt.Println("failed at decodeString stdEncoding") + return false + } + + if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil { + fmt.Println("failed at verify") + fmt.Println(err) return false } serverPublicKey, err := auth.GetPublicKey() if err != nil { + fmt.Println("failed at getPublicKey") log.Fatal(err) } - _, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" ")) - requestData := fmt.Sprintf("%s,%s", key, partialRequestData) - - if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil { - return false - } - if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) { + fmt.Println("true") return true }