Compare commits

...

18 Commits

Author SHA1 Message Date
Michael Yang
cedae0d17a Merge pull request #1347 from jshph/adapter-hash
Fix adapter loading from SHA hash
2023-12-01 11:08:25 -08:00
Joshua Pham
bb80a597db Fix adapter loading from SHA hash 2023-12-01 13:50:55 -05:00
Patrick Devine
6681d37861 allow setting the system and template for prompts in the repl (#1335) 2023-12-01 09:28:35 -08:00
Michael Yang
0409c1fa59 docker: set PATH, LD_LIBRARY_PATH, and capabilities (#1336)
* docker: set PATH, LD_LIBRARY_PATH, and capabilities

* example: update k8s gpu manifest
2023-11-30 21:16:56 -08:00
Michael Yang
b56e92470a Merge pull request #1229 from jmorganca/mxyng/calculate-as-you-go
revert checksum calculation to calculate-as-you-go
2023-11-30 10:54:38 -08:00
Jeffrey Morgan
5687f1a0cf fix unexpected end of response errors when cancelling in ollama run 2023-11-30 00:30:21 -05:00
James Radtke
7eda3d0c55 Corrected transposed 129 to 192 for OLLAMA_ORIGINS example (#1325) 2023-11-29 22:44:17 -05:00
Bruce MacDonald
7194a07d4d Add chatd to example projects 2023-11-29 21:18:21 -05:00
Michael Yang
13efd5f218 upload: fix PUT retry 2023-11-29 16:38:35 -08:00
Michael Yang
c4bdfffd96 upload: separate progress tracking 2023-11-29 16:38:33 -08:00
Michael Yang
26c63418e0 new hasher 2023-11-29 14:52:41 -08:00
Michael Yang
2799784ac8 revert checksum calculation to calculate-as-you-go 2023-11-29 13:47:58 -08:00
Alec Hammond
91897a606f Add OllamaEmbeddings to python LangChain example (#994)
* Add OllamaEmbeddings to python LangChain example

* typo

---------

Co-authored-by: Alec Hammond <alechammond@fb.com>
2023-11-29 16:25:39 -05:00
Bruce MacDonald
96122b7271 validate model tags on copy (#1323) 2023-11-29 15:54:29 -05:00
jeremiahbuckley
39be7fdb98 fix rhel cuda install (#1321)
Co-authored-by: Cloud User <azureuser@testgpu2.hqzwom21okjenksna4y3c4ymjd.phxx.internal.cloudapp.net>
2023-11-29 14:55:15 -05:00
Timothy Jaeryang Baek
c2e3b89176 fix: disable ':' in tag names (#1280)
Co-authored-by: rootedbox
2023-11-29 13:33:45 -05:00
Patrick Devine
cde31cb220 Allow setting parameters in the REPL (#1294) 2023-11-29 09:56:42 -08:00
ToasterUwU
63097607b2 Correct MacOS Host port example (#1301) 2023-11-29 11:44:03 -05:00
12 changed files with 325 additions and 150 deletions

View File

@@ -19,5 +19,11 @@ RUN apt-get update && apt-get install -y ca-certificates
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434 EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0 ENV OLLAMA_HOST 0.0.0.0
# set some environment variables for better NVIDIA compatibility
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENTRYPOINT ["/bin/ollama"] ENTRYPOINT ["/bin/ollama"]
CMD ["serve"] CMD ["serve"]

View File

@@ -233,6 +233,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md) - [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core) - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica) - [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
### Terminal ### Terminal

View File

@@ -6,6 +6,7 @@ import (
"math" "math"
"os" "os"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
) )
@@ -360,3 +361,63 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
return nil return nil
} }
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; !ok {
return nil, fmt.Errorf("unknown parameter '%s'", key)
} else {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}

View File

@@ -412,10 +412,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
} }
func RunGenerate(cmd *cobra.Command, args []string) error { func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true
opts := generateOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
if err != nil { if err != nil {
return err return err
} }
opts.Format = format
prompts := args[1:] prompts := args[1:]
@@ -427,34 +436,40 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
} }
prompts = append([]string{string(in)}, prompts...) prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
} }
opts.Prompt = strings.Join(prompts, " ")
// output is being piped if len(prompts) > 0 {
if !term.IsTerminal(int(os.Stdout.Fd())) { interactive = false
return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
} }
wordWrap := os.Getenv("TERM") == "xterm-256color"
nowrap, err := cmd.Flags().GetBool("nowordwrap") nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil { if err != nil {
return err return err
} }
if nowrap { opts.WordWrap = !nowrap
wordWrap = false
if !interactive {
return generate(cmd, opts)
} }
// prompts are provided via stdin or args so don't enter interactive mode return generateInteractive(cmd, opts)
if len(prompts) > 0 {
return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
}
return generateInteractive(cmd, args[0], wordWrap, format)
} }
type generateContextKey string type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error { type generateOptions struct {
Model string
Prompt string
WordWrap bool
Format string
System string
Template string
Options map[string]interface{}
}
func generate(cmd *cobra.Command, opts generateOptions) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return err
@@ -475,7 +490,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil { if err != nil {
wordWrap = false opts.WordWrap = false
} }
cancelCtx, cancel := context.WithCancel(context.Background()) cancelCtx, cancel := context.WithCancel(context.Background())
@@ -483,24 +498,30 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT) signal.Notify(sigChan, syscall.SIGINT)
var abort bool
go func() { go func() {
<-sigChan <-sigChan
cancel() cancel()
abort = true
}() }()
var currentLineLength int var currentLineLength int
var wordBuffer string var wordBuffer string
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format} request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
p.StopAndClear() p.StopAndClear()
latest = response latest = response
if wordWrap { if opts.WordWrap {
for _, ch := range response.Response { for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 { if currentLineLength+1 > termWidth-5 {
// backtrack the length of the last word and clear to the end of the line // backtrack the length of the last word and clear to the end of the line
@@ -529,21 +550,18 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
} }
if err := client.Generate(cancelCtx, &request, fn); err != nil { if err := client.Generate(cancelCtx, &request, fn); err != nil {
if strings.Contains(err.Error(), "context canceled") && abort { if errors.Is(err, context.Canceled) {
return nil return nil
} }
return err return err
} }
if prompt != "" { if opts.Prompt != "" {
fmt.Println() fmt.Println()
fmt.Println() fmt.Println()
} }
if !latest.Done { if !latest.Done {
if abort { return nil
return nil
}
return errors.New("unexpected end of response")
} }
verbose, err := cmd.Flags().GetBool("verbose") verbose, err := cmd.Flags().GetBool("verbose")
@@ -562,9 +580,22 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
return nil return nil
} }
func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error { type MultilineState int
const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
// load the model // load the model
if err := generate(cmd, model, "", false, ""); err != nil { loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
}
if err := generate(cmd, loadOpts); err != nil {
return err return err
} }
@@ -581,14 +612,17 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
usageSet := func() { usageSet := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set system <string> Set system prompt")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap") fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode") fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") fmt.Fprintln(os.Stderr, " /set nowordwrap Disable wordwrap")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") fmt.Fprintln(os.Stderr, " /set format json Enable JSON mode")
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@@ -602,6 +636,22 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
// only list out the most common parameters
usageParameters := func() {
fmt.Fprintln(os.Stderr, "Available Parameters:")
fmt.Fprintln(os.Stderr, " /set parameter seed <int> Random number seed")
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "")
}
scanner, err := readline.New(readline.Prompt{ scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ", Prompt: ">>> ",
AltPrompt: "... ", AltPrompt: "... ",
@@ -615,6 +665,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Print(readline.StartBracketedPaste) fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste) defer fmt.Printf(readline.EndBracketedPaste)
var multiline MultilineState
var prompt string var prompt string
for { for {
@@ -649,8 +700,21 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
prompt = strings.TrimPrefix(prompt, `"""`) prompt = strings.TrimPrefix(prompt, `"""`)
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
switch multiline {
case MultilineSystem:
opts.System = prompt
prompt = ""
fmt.Println("Set system template.\n")
case MultilineTemplate:
opts.Template = prompt
prompt = ""
fmt.Println("Set model template.\n")
}
multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0: case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
scanner.Prompt.UseAlt = true scanner.Prompt.UseAlt = true
multiline = MultilinePrompt
prompt += line + "\n" prompt += line + "\n"
continue continue
case scanner.Pasting: case scanner.Pasting:
@@ -670,10 +734,10 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
case "nohistory": case "nohistory":
scanner.HistoryDisable() scanner.HistoryDisable()
case "wordwrap": case "wordwrap":
wordWrap = true opts.WordWrap = true
fmt.Println("Set 'wordwrap' mode.") fmt.Println("Set 'wordwrap' mode.")
case "nowordwrap": case "nowordwrap":
wordWrap = false opts.WordWrap = false
fmt.Println("Set 'nowordwrap' mode.") fmt.Println("Set 'nowordwrap' mode.")
case "verbose": case "verbose":
cmd.Flags().Set("verbose", "true") cmd.Flags().Set("verbose", "true")
@@ -685,12 +749,59 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if len(args) < 3 || args[2] != "json" { if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
} else { } else {
format = args[2] opts.Format = args[2]
fmt.Printf("Set format to '%s' mode.\n", args[2]) fmt.Printf("Set format to '%s' mode.\n", args[2])
} }
case "noformat": case "noformat":
format = "" opts.Format = ""
fmt.Println("Disabled format.") fmt.Println("Disabled format.")
case "parameter":
if len(args) < 4 {
usageParameters()
continue
}
var params []string
for _, p := range args[3:] {
params = append(params, p)
}
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system", "template":
if len(args) < 3 {
usageSet()
continue
}
line := strings.Join(args[2:], " ")
line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`)
prompt += cut + "\n"
if found {
opts.System = prompt
if args[1] == "system" {
fmt.Println("Set system template.\n")
} else {
fmt.Println("Set prompt template.\n")
}
prompt = ""
} else {
prompt = `"""` + prompt
if args[1] == "system" {
multiline = MultilineSystem
} else {
multiline = MultilineTemplate
}
scanner.Prompt.UseAlt = true
}
} else {
opts.System = line
fmt.Println("Set system template.\n")
}
default: default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1]) fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
} }
@@ -705,7 +816,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
fmt.Println("error: couldn't connect to ollama server") fmt.Println("error: couldn't connect to ollama server")
return err return err
} }
resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: model}) resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: opts.Model})
if err != nil { if err != nil {
fmt.Println("error: couldn't get model") fmt.Println("error: couldn't get model")
return err return err
@@ -724,19 +835,33 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
if resp.Parameters == "" { if resp.Parameters == "" {
fmt.Print("No parameters were specified for this model.\n\n") fmt.Print("No parameters were specified for this model.\n\n")
} else { } else {
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf("%-*s %v\n", 30, k, v)
}
fmt.Println()
}
fmt.Println("Model defined parameters:")
fmt.Println(resp.Parameters) fmt.Println(resp.Parameters)
} }
case "system": case "system":
if resp.System == "" { switch {
case opts.System != "":
fmt.Println(opts.System + "\n")
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Print("No system prompt was specified for this model.\n\n") fmt.Print("No system prompt was specified for this model.\n\n")
} else {
fmt.Println(resp.System)
} }
case "template": case "template":
if resp.Template == "" { switch {
fmt.Print("No prompt template was specified for this model.\n\n") case opts.Template != "":
} else { fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template) fmt.Println(resp.Template)
default:
fmt.Print("No prompt template was specified for this model.\n\n")
} }
default: default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1]) fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
@@ -766,8 +891,9 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
prompt += line prompt += line
} }
if len(prompt) > 0 && prompt[0] != '/' { if len(prompt) > 0 && multiline == MultilineNone {
if err := generate(cmd, model, prompt, wordWrap, format); err != nil { opts.Prompt = prompt
if err := generate(cmd, opts); err != nil {
return err return err
} }

View File

@@ -23,7 +23,7 @@ Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with th
On macOS: On macOS:
```bash ```bash
OLLAMA_HOST=0.0.0.0:11435 ollama serve OLLAMA_HOST=0.0.0.0:11434 ollama serve
``` ```
On Linux: On Linux:
@@ -59,7 +59,7 @@ OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com ollama serve
On Linux: On Linux:
```bash ```bash
echo 'Environment="OLLAMA_ORIGINS=http://129.168.1.1:*,https://example.com"' >>/etc/systemd/system/ollama.service.d/environment.conf echo 'Environment="OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com"' >>/etc/systemd/system/ollama.service.d/environment.conf
``` ```
Reload `systemd` and restart Ollama: Reload `systemd` and restart Ollama:

View File

@@ -42,12 +42,13 @@ text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data) all_splits = text_splitter.split_documents(data)
``` ```
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. For now, we don't have embeddings built in to Ollama, though we will be adding that soon, so for now, we can use the GPT4All library for that. We will use ChromaDB in this example for a vector database. `pip install GPT4All chromadb` It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
```python ```python
from langchain.embeddings import GPT4AllEmbeddings from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings()) oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="llama2")
vectorstore = Chroma.from_documents(documents=all_splits, embedding=oembed)
``` ```
Now let's ask a question from the document. **Who was Neleus, and who is in his family?** Neleus is a character in the Odyssey, and the answer can be found in our text. Now let's ask a question from the document. **Who was Neleus, and who is in his family?** Neleus is a character in the Odyssey, and the answer can be found in our text.

View File

@@ -25,9 +25,11 @@ spec:
image: ollama/ollama:latest image: ollama/ollama:latest
env: env:
- name: PATH - name: PATH
value: /usr/local/nvidia/bin:/usr/local/nvidia/lib64:/usr/bin:/usr/sbin:/bin:/sbin value: /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
- name: LD_LIBRARY_PATH - name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib64 value: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: NVIDIA_DRIVER_CAPABILITIES
value: compute,utility
ports: ports:
- name: http - name: http
containerPort: 11434 containerPort: 11434

View File

@@ -217,7 +217,7 @@ fi
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
case $OS_NAME in case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $OS_VERSION ;; centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;; rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;; fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
amzn) install_cuda_driver_yum 'fedora' '35' ;; amzn) install_cuda_driver_yum 'fedora' '35' ;;

View File

@@ -14,7 +14,6 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@@ -376,6 +375,15 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
layer.MediaType = mediatype layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
case "adapter": case "adapter":
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
c.Args = blobPath
}
fn(api.ProgressResponse{Status: "creating adapter layer"}) fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args)) bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil { if err != nil {
@@ -426,7 +434,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := formatParams(params) formattedParams, err := api.FormatParams(params)
if err != nil { if err != nil {
return err return err
} }
@@ -581,64 +589,6 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil return newLayer, nil
} }
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) { func getLayerDigests(layers []*LayerReader) ([]string, error) {
var digests []string var digests []string
for _, l := range layers { for _, l := range layers {

View File

@@ -67,6 +67,20 @@ func ParseModelPath(name string) ModelPath {
return mp return mp
} }
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
func (mp ModelPath) GetNamespaceRepository() string { func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository) return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
} }

View File

@@ -416,6 +416,11 @@ func CreateModelHandler(c *gin.Context) {
return return
} }
if err := ParseModelPath(req.Name).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Path == "" && req.Modelfile == "" { if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return return
@@ -640,6 +645,11 @@ func CopyModelHandler(c *gin.Context) {
return return
} }
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(req.Source, req.Destination); err != nil { if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})

View File

@@ -5,6 +5,7 @@ import (
"crypto/md5" "crypto/md5"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"log" "log"
"math" "math"
@@ -102,7 +103,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
} }
// set part.N to the current number of parts // set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{blobUpload: b, N: len(b.Parts), Offset: offset, Size: size}) b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size})
offset += size offset += size
} }
@@ -147,14 +148,13 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
g.Go(func() error { g.Go(func() error {
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(inner, http.MethodPatch, requestURL, part, opts) err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
return err return err
case errors.Is(err, errMaxRetriesExceeded): case errors.Is(err, errMaxRetriesExceeded):
return err return err
case err != nil: case err != nil:
part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep) log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep) time.Sleep(sleep)
@@ -176,17 +176,10 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
requestURL := <-b.nextURL requestURL := <-b.nextURL
var sb strings.Builder
// calculate md5 checksum and add it to the commit request // calculate md5 checksum and add it to the commit request
var sb strings.Builder
for _, part := range b.Parts { for _, part := range b.Parts {
hash := md5.New() sb.Write(part.Sum(nil))
if _, err := io.Copy(hash, io.NewSectionReader(b.file, part.Offset, part.Size)); err != nil {
b.err = err
return
}
sb.Write(hash.Sum(nil))
} }
md5sum := md5.Sum([]byte(sb.String())) md5sum := md5.Sum([]byte(sb.String()))
@@ -201,27 +194,25 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
headers.Set("Content-Length", "0") headers.Set("Content-Length", "0")
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) var resp *http.Response
if err != nil { resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
b.err = err if errors.Is(err, context.Canceled) {
if errors.Is(err, context.Canceled) { break
return } else if err != nil {
}
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep) log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
time.Sleep(sleep) time.Sleep(sleep)
continue continue
} }
defer resp.Body.Close() defer resp.Body.Close()
break
b.err = nil
b.done = true
return
} }
b.err = err
b.done = true
} }
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -232,8 +223,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
} }
sr := io.NewSectionReader(b.file, part.Offset, part.Size) sr := io.NewSectionReader(b.file, part.Offset, part.Size)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, part), opts)
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil { if err != nil {
w.Rollback()
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -245,11 +241,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
nextURL, err := url.Parse(location) nextURL, err := url.Parse(location)
if err != nil { if err != nil {
w.Rollback()
return err return err
} }
switch { switch {
case resp.StatusCode == http.StatusTemporaryRedirect: case resp.StatusCode == http.StatusTemporaryRedirect:
w.Rollback()
b.nextURL <- nextURL b.nextURL <- nextURL
redirectURL, err := resp.Location() redirectURL, err := resp.Location()
@@ -259,14 +257,13 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
// retry uploading to the redirect URL // retry uploading to the redirect URL
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.uploadChunk(ctx, http.MethodPut, redirectURL, part, nil) err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
return err return err
case errors.Is(err, errMaxRetriesExceeded): case errors.Is(err, errMaxRetriesExceeded):
return err return err
case err != nil: case err != nil:
part.Reset()
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep) log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep) time.Sleep(sleep)
@@ -279,6 +276,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err) return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
auth := resp.Header.Get("www-authenticate") auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth) authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir) token, err := getAuthToken(ctx, authRedir)
@@ -289,6 +287,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
opts.Token = token opts.Token = token
fallthrough fallthrough
case resp.StatusCode >= http.StatusBadRequest: case resp.StatusCode >= http.StatusBadRequest:
w.Rollback()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return err
@@ -301,6 +300,7 @@ func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL
b.nextURL <- nextURL b.nextURL <- nextURL
} }
part.Hash = md5sum
return nil return nil
} }
@@ -341,22 +341,26 @@ func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) er
type blobUploadPart struct { type blobUploadPart struct {
// N is the part number // N is the part number
N int N int
Offset int64 Offset int64
Size int64 Size int64
hash.Hash
}
type progressWriter struct {
written int64 written int64
*blobUpload *blobUpload
} }
func (p *blobUploadPart) Write(b []byte) (n int, err error) { func (p *progressWriter) Write(b []byte) (n int, err error) {
n = len(b) n = len(b)
p.written += int64(n) p.written += int64(n)
p.Completed.Add(int64(n)) p.Completed.Add(int64(n))
return n, nil return n, nil
} }
func (p *blobUploadPart) Reset() { func (p *progressWriter) Rollback() {
p.Completed.Add(-int64(p.written)) p.Completed.Add(-p.written)
p.written = 0 p.written = 0
} }