reference license, template, system as files

This commit is contained in:
Michael Yang 2024-07-01 08:53:25 -07:00
parent e401a23d62
commit b7860f12ad
2 changed files with 62 additions and 36 deletions

View File

@ -85,8 +85,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
fsys := os.DirFS(createCtx) fsys := os.DirFS(createCtx)
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { if slices.Contains([]string{"model", "adapter", "license", "template", "system"}, modelfile.Commands[i].Name) {
case "model", "adapter":
p := filepath.Clean(modelfile.Commands[i].Args) p := filepath.Clean(modelfile.Commands[i].Args)
fi, err := fs.Stat(fsys, p) fi, err := fs.Stat(fsys, p)
@ -96,6 +95,8 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
switch modelfile.Commands[i].Name {
case "model", "adapter":
if fi.IsDir() { if fi.IsDir() {
// this is likely a safetensors or pytorch directory // this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters // TODO make this work w/ adapters
@ -120,6 +121,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
} }
case "license", "template", "system":
t, err := detectContentType(fsys, p)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return err
} else if t != "text/plain" {
return fmt.Errorf("invalid content type for %s: %s", modelfile.Commands[i].Name, p)
}
}
digest, err := createBlob(cmd, client, fsys, p) digest, err := createBlob(cmd, client, fsys, p)
if err != nil { if err != nil {
@ -164,8 +175,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func zipFiles(fsys fs.FS, w io.Writer) error { func detectContentType(fsys fs.FS, name string) (string, error) {
detectContentType := func(name string) (string, error) {
f, err := fsys.Open(name) f, err := fsys.Open(name)
if err != nil { if err != nil {
return "", err return "", err
@ -179,8 +189,9 @@ func zipFiles(fsys fs.FS, w io.Writer) error {
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";") contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil return contentType, nil
} }
func zipFiles(fsys fs.FS, w io.Writer) error {
glob := func(pattern, contentType string) ([]string, error) { glob := func(pattern, contentType string) ([]string, error) {
matches, err := fs.Glob(fsys, pattern) matches, err := fs.Glob(fsys, pattern)
if err != nil { if err != nil {
@ -188,7 +199,7 @@ func zipFiles(fsys fs.FS, w io.Writer) error {
} }
for _, match := range matches { for _, match := range matches {
if ct, err := detectContentType(match); err != nil { if ct, err := detectContentType(fsys, match); err != nil {
return nil, err return nil, err
} else if ct != contentType { } else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)

View File

@ -459,7 +459,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
}) })
} }
blob := strings.NewReader(c.Args) var blob io.Reader = strings.NewReader(c.Args)
if strings.HasPrefix(c.Args, "@") {
p, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
b, err := os.Open(p)
if err != nil {
return err
}
defer b.Close()
blob = b
}
layer, err := NewLayer(blob, mediatype) layer, err := NewLayer(blob, mediatype)
if err != nil { if err != nil {
return err return err