diff --git a/api/client.go b/api/client.go index 8ad68795f..3c47ac39b 100644 --- a/api/client.go +++ b/api/client.go @@ -403,3 +403,12 @@ func (c *Client) IsLocal() bool { return false } + +// EnvConfig returns the environment configuration for the server. +func (c *Client) ServerConfig(ctx context.Context) (*ServerConfig, error) { + var config ServerConfig + if err := c.do(ctx, http.MethodGet, "/api/config", nil, &config); err != nil { + return nil, err + } + return &config, nil +} diff --git a/api/types.go b/api/types.go index 65a99c763..3b10b17f4 100644 --- a/api/types.go +++ b/api/types.go @@ -451,6 +451,11 @@ type ModelDetails struct { QuantizationLevel string `json:"quantization_level"` } +// EnvConfig is the configuration for the environment. +type ServerConfig struct { + ModelDir string `json:"model_dir"` +} + func (m *Metrics) Summary() { if m.TotalDuration > 0 { fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) diff --git a/cmd/cmd.go b/cmd/cmd.go index a7c9cb54d..d89265a7b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -287,7 +287,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er // Resolve server to IP // Check if server is local if client.IsLocal() { - err := createBlobLocal(cmd, client, digest) + err := createBlobLocal(cmd.Context(), client, path, digest) if err == nil { return digest, nil } @@ -299,10 +299,66 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er return digest, nil } -func createBlobLocal(cmd *cobra.Command, client *api.Client, digest string) error { +func createBlobLocal(ctx context.Context, client *api.Client, path string, digest string) error { // This function should be called if the server is local // It should find the model directory, copy the blob over, and return the digest + // Get the model directory + config, err := client.ServerConfig(ctx) + if err != nil { + return err + } + + modelDir := config.ModelDir + + // Get blob destination + digest = strings.ReplaceAll(digest, ":", "-") + dest := filepath.Join(modelDir, "blobs", digest) + dirPath := filepath.Dir(dest) + if digest == "" { + dirPath = dest + } + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return err + } + + // Check blob exists + _, err = os.Stat(dest) + switch { + case errors.Is(err, os.ErrNotExist): + // noop + case err != nil: + return err + default: + // blob already exists + return nil + } + + // Copy blob over + sourceFile, err := os.Open(path) + if err != nil { + return fmt.Errorf("could not open source file: %v", err) + } + defer sourceFile.Close() + + destFile, err := os.Create(dest) + if err != nil { + return fmt.Errorf("could not create destination file: %v", err) + } + defer destFile.Close() + + _, err = io.Copy(destFile, sourceFile) + if err != nil { + return fmt.Errorf("error copying file: %v", err) + } + + err = destFile.Sync() + if err != nil { + return fmt.Errorf("error flushing file: %v", err) + } + + return nil } func RunHandler(cmd *cobra.Command, args []string) error { diff --git a/server/routes.go b/server/routes.go index 0d7ca003c..faf6ad6f1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1069,6 +1069,7 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) r.GET("/api/ps", s.ProcessHandler) + r.GET("/api/config", s.ConfigHandler) // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) @@ -1422,3 +1423,7 @@ func handleScheduleError(c *gin.Context, name string, err error) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } } + +func (s *Server) ConfigHandler(c *gin.Context) { + c.JSON(http.StatusOK, api.ServerConfig{ModelDir: envconfig.ModelsDir}) +}