Compare commits

...

74 Commits

Author SHA1 Message Date
Jeffrey Morgan
7db5bcf73b fix go-staticcheck warning 2023-12-10 11:44:27 -05:00
Jeffrey Morgan
fa2f095bd9 fix model name returned by /api/generate being different than the model name provided 2023-12-10 11:42:15 -05:00
Jeffrey Morgan
045b855db9 fix error on accumulating final chat response 2023-12-10 11:24:39 -05:00
Jeffrey Morgan
32064a0646 fix empty response when receiving runner error 2023-12-10 10:53:38 -05:00
Jeffrey Morgan
d9a250e9b5 seek to end of file when decoding older model formats 2023-12-09 21:14:35 -05:00
Jeffrey Morgan
944519ed16 seek to eof for older model binaries 2023-12-09 20:48:57 -05:00
Jeffrey Morgan
2dd040d04c do not use --parallel 2 for old runners 2023-12-09 20:17:33 -05:00
Bruce MacDonald
bbe41ce41a fix: parallel queueing race condition caused silent failure (#1445)
* fix: queued request failures

- increase parallel requests to 2 to complete queued request, queueing is managed in ollama

* log stream errors
2023-12-09 14:14:02 -05:00
Jeffrey Morgan
9e1406e4ed Don't expose model information in /api/generate 2023-12-09 02:05:43 -08:00
Jeffrey Morgan
b74580c913 Update api.md 2023-12-08 16:02:07 -08:00
Bruce MacDonald
7e9405fd07 fix: encode full previous prompt in context (#1424) 2023-12-08 16:53:51 -05:00
Bruce MacDonald
3b0b8930d4 fix: only flush template in chat when current role encountered (#1426) 2023-12-08 16:44:24 -05:00
Bruce MacDonald
e3f925fc1b fix: restore modelfile system in prompt template (#1425) 2023-12-08 14:20:19 -05:00
Jeffrey Morgan
2a2289fb6b Update api.md 2023-12-08 09:36:45 -08:00
Matt Williams
dd427f499a Merge pull request #1419 from jmorganca/mattw/typescript-simplechat
Simple chat example for typescript
2023-12-07 14:42:24 -08:00
Michael Yang
2ae573c7ed Merge pull request #1421 from jmorganca/mxyng/fix-newline
fix redundant newline
2023-12-07 13:47:23 -08:00
Matt Williams
02fe26c44b update the readme as per bruce
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-07 13:46:30 -08:00
Michael Yang
16c7548460 fix redundant newline 2023-12-07 13:44:45 -08:00
Matt Williams
fa75998c0d Update examples/typescript-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:40:54 -08:00
Matt Williams
5344f886c8 Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:40:37 -08:00
Matt Williams
6cc823c9b5 Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:59 -08:00
Matt Williams
b84d34e632 Update examples/typescript-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:33 -08:00
Matt Williams
30229a913c Update examples/typescript-simplechat/client.ts
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-07 13:39:24 -08:00
Matt Williams
1ade380bd7 Simple chat example for typescript
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-07 11:48:25 -08:00
Jeffrey Morgan
ba264e9da8 add future version note to chat api docs 2023-12-07 09:42:15 -08:00
Matt Williams
a2405ec831 Merge pull request #1409 from jmorganca/mattw/python-simplechat
Simple chat example
2023-12-06 15:49:45 -08:00
Matt Williams
ce809bb529 Merge branch 'mattw/python-simplechat' of github.com:jmorganca/ollama into mattw/python-simplechat 2023-12-06 15:48:42 -08:00
Matt Williams
76bc4d0458 Cleanup as per Bruce
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-06 15:44:40 -08:00
Bruce MacDonald
4a02945a15 Update examples/python-simplechat/client.py 2023-12-06 18:36:45 -05:00
Matt Williams
aec742b6d2 Update examples/python-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:45 -08:00
Matt Williams
f337642e94 Update examples/python-simplechat/readme.md
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:35 -08:00
Matt Williams
51131cc6e2 Update examples/python-simplechat/client.py
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2023-12-06 15:30:10 -08:00
Matt Williams
43027789dc Simple chat example
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-06 14:35:58 -08:00
Xe Iaso
f9b7d65e2b docs/tutorials: add bit on how to use Fly GPUs on-demand with Ollama (#1406)
Signed-off-by: Xe Iaso <xe@camellia.finch-kitefin.ts.net>
2023-12-06 14:14:02 -08:00
Michael Yang
1f05d77110 Merge pull request #1244 from jmorganca/brucemacd/no-fail-template
do not fail on unsupported template variables
2023-12-06 13:23:04 -08:00
Michael Yang
c3ff36088b Merge pull request #774 from jmorganca/mxyng/server-version
add version api and show server version in cli
2023-12-06 13:22:55 -08:00
Samuel Calderon
13524b5e72 List "Send chat messages" in table of contents (#1399)
Thank you @calderonsamuel
2023-12-06 12:34:27 -08:00
Michael Yang
f1b049fed8 Merge pull request #1377 from jmorganca/mxyng/qwen
update for qwen
2023-12-06 12:31:51 -08:00
Jeffrey Morgan
97c5696945 fix base urls in chat examples 2023-12-06 12:10:20 -08:00
Bruce MacDonald
47d4e22673 use missingkey in set empty interface when missing 2023-12-05 15:49:05 -08:00
Michael Yang
32f62fbb8e Merge pull request #1334 from jmorganca/mxyng/load-projectors
load projectors
2023-12-05 14:40:53 -08:00
Michael Yang
5d75505ebd return model configuration in generate 2023-12-05 14:39:02 -08:00
Michael Yang
b9495ea162 load projectors 2023-12-05 14:36:12 -08:00
Michael Yang
409bb9674e Merge pull request #1308 from jmorganca/mxyng/split-from
split from into one or more models
2023-12-05 14:33:03 -08:00
Michael Yang
d3479c07a1 Merge pull request #1250 from jmorganca/mxyng/create-layer
refactor layer creation
2023-12-05 14:32:52 -08:00
Michael Yang
b12f1b984f Merge pull request #1393 from jmorganca/mxyng/fix-whitespace
fix: trim space in modelfile fields
2023-12-05 12:18:01 -08:00
Bruce MacDonald
195e3d9dbd chat api endpoint (#1392) 2023-12-05 14:57:33 -05:00
Michael Yang
38fe1a368b fix: trim space in modelfile fields 2023-12-05 11:57:29 -08:00
Michael Yang
4b77fcb2b9 comments 2023-12-05 09:43:50 -08:00
Michael Yang
cde13bcdea cmd: only print server version when different 2023-12-05 09:36:01 -08:00
Michael Yang
0f0cd265a7 cmd: add server version 2023-12-05 09:36:01 -08:00
Michael Yang
0db4706ec2 api: add version api handler 2023-12-05 09:36:01 -08:00
Michael Yang
1ebdbd9694 server: add version handler 2023-12-05 09:36:01 -08:00
Michael Yang
5c59455b59 cmd: use existing cmd context 2023-12-05 09:36:01 -08:00
Jeffrey Morgan
00d06619a1 Revert "chat api (#991)" while context variable is fixed
This reverts commit 7a0899d62d.
2023-12-04 21:16:27 -08:00
Matt Williams
f1ef3f9947 remove mention of gpt-neox in import (#1381)
Signed-off-by: Matt Williams <m@technovangelist.com>
2023-12-04 20:58:10 -08:00
Michael Yang
5a5dca13b2 comments 2023-12-04 16:59:23 -08:00
Michael Yang
7232f1fa41 go mod tidy 2023-12-04 16:59:23 -08:00
Michael Yang
72e7a49aa9 seek instead of copyn 2023-12-04 16:59:23 -08:00
Michael Yang
a3737cbd33 use NewLayer for CreateBlobHandler 2023-12-04 16:59:23 -08:00
Michael Yang
998f1785b6 add modelfamilies 2023-12-04 16:59:23 -08:00
Michael Yang
70a93057cd refactor layer creation
previous layer creation was not ideal because:

1. it required reading the input file multiple times, once to calculate
   the sha256 checksum, another to write it to disk, and potentially one
   more to decode the underlying gguf
2. used io.ReadSeeker which is prone to user error. if the file isn't
   reset correctly or in the right place, it could end up reading an
   empty file

there is also some brittleness when reading existing layers; otherwise
writing the inherited layers will error reading an already closed file

this commit aims to fix these issues by restructuring layer creation.

1. it will now write the layer to a temporary file, hashing it in the same
   pass, and move it to the final location on Commit
2. layers are read only once when copied to the destination. the exception
   is raw model files, which still require a second read to decode the
   model metadata (a minimal Go sketch of this pattern follows the commit
   list)
2023-12-04 16:59:23 -08:00
Michael Yang
2cb0fa7d40 split from into one or more models 2023-12-04 16:59:23 -08:00
Michael Yang
b2816bca67 unnecessary ReadSeeker for DecodeGGML 2023-12-04 16:59:23 -08:00
Patrick Devine
bf704423c5 revert cli to use /api/generate (#1383) 2023-12-04 16:35:29 -08:00
Bruce MacDonald
7a0899d62d chat api (#991)
- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs
- context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history
2023-12-04 18:01:06 -05:00
Michael Yang
0cca1486dd Merge pull request #1376 from jmorganca/mxyng/rocky-install
install: fix rocky kernel packages
2023-12-04 14:23:43 -08:00
Patrick Devine
2113c9d31a make linewrap still work when the terminal width has changed (#1350) 2023-12-04 14:14:56 -08:00
Michael Yang
6deebf2489 update for qwen 2023-12-04 11:38:05 -08:00
Michael Yang
95cb38ae47 install: fix rocky kernel packages 2023-12-04 11:10:42 -08:00
ruecat
1f126afb2d Ollama Telegram Bot (#1364)
* Add "ollama-telegram" to Extensions & Plugins

* Update README.md
2023-12-03 11:19:55 -08:00
Jeffrey Morgan
f6201a7a6c remove duplicate community integration in README.md 2023-12-02 21:18:13 -08:00
Michael Yang
b3f6c6598f Merge pull request #1349 from jmorganca/mxyng/ctrl-z
handle ctrl+z
2023-12-01 16:21:49 -08:00
Michael Yang
88620e983a handle ctrl+z 2023-12-01 16:15:20 -08:00
27 changed files with 1559 additions and 532 deletions
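The "refactor layer creation" commit (70a93057cd) above describes streaming each layer through the hash function into a temporary file in a single pass and renaming it into place on Commit. A minimal Go sketch of that pattern, using hypothetical names (`Layer`, `NewLayer`, `blobsDir`) rather than the repository's actual types:

```go
package blob

import (
	"crypto/sha256"
	"fmt"
	"io"
	"os"
	"path/filepath"
)

// Layer is a hypothetical stand-in for the repository's layer type.
type Layer struct {
	Digest   string
	Size     int64
	tempname string
}

// NewLayer writes r to a temporary file while hashing it, so the input
// is read exactly once.
func NewLayer(r io.Reader, dir string) (*Layer, error) {
	tmp, err := os.CreateTemp(dir, "layer-")
	if err != nil {
		return nil, err
	}
	defer tmp.Close()

	h := sha256.New()
	n, err := io.Copy(io.MultiWriter(tmp, h), r) // write and hash in one pass
	if err != nil {
		return nil, err
	}

	return &Layer{
		Digest:   fmt.Sprintf("sha256:%x", h.Sum(nil)),
		Size:     n,
		tempname: tmp.Name(),
	}, nil
}

// Commit moves the temporary file to its final, content-addressed location.
func (l *Layer) Commit(blobsDir string) error {
	return os.Rename(l.tempname, filepath.Join(blobsDir, l.Digest))
}
```

The single `io.MultiWriter` pass is what removes the repeated reads the commit message calls out; per the commit, only raw model files still need a second read to decode their metadata.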

View File

@@ -214,14 +214,21 @@ curl http://localhost:11434/api/generate -d '{
}'
```
Or send a chat message (coming in 0.1.14):
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"messages": [
{ "role": "user", "content": "why is the sky blue?" }
]
}'
```
See the [API documentation](./docs/api.md) for all endpoints.
## Community Integrations
### Mobile
- [Mobile Artificial Intelligence Distribution](https://github.com/MaidFoundation/Maid) (Maid)
### Web & Desktop
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
@@ -277,6 +284,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq)
- [Dagger Chatbot](https://github.com/samalba/dagger-chatbot)
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)

View File

@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
})
}
type ChatResponseFunc func(ChatResponse) error
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
var resp ChatResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
type PullProgressFunc func(ProgressResponse) error
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
@@ -311,3 +324,15 @@ func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) err
return nil
}
func (c *Client) Version(ctx context.Context) (string, error) {
var version struct {
Version string `json:"version"`
}
if err := c.do(ctx, http.MethodGet, "/api/version", nil, &version); err != nil {
return "", err
}
return version.Version, nil
}
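For context, a minimal sketch of calling the `Chat` method added above, built against the `ChatRequest`, `Message`, and `ChatResponse` types introduced in this change; the module import path and model name are assumptions for illustration:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// ClientFromEnvironment reads OLLAMA_HOST, defaulting to the local server.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2",
		Messages: []api.Message{
			{Role: "user", Content: "why is the sky blue?"},
		},
	}

	// Each streamed response carries one chunk of the assistant's message.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		if resp.Message != nil {
			fmt.Print(resp.Message.Content)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```

Because `Stream` is left unset, responses stream by default and the callback runs once per chunk, mirroring the curl examples in docs/api.md below.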

View File

@@ -44,6 +44,39 @@ type GenerateRequest struct {
Options map[string]interface{} `json:"options"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
Options map[string]interface{} `json:"options"`
}
type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"`
}
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message *Message `json:"message,omitempty"`
Done bool `json:"done"`
Metrics
}
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
// Options specified in GenerateRequest; if you add a new option here, add it to the API docs also
type Options struct {
Runner
@@ -173,39 +206,34 @@ type GenerateResponse struct {
Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
Metrics
}
func (r *GenerateResponse) Summary() {
if r.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
}
if r.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
}
if r.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
if m.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
}
if r.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
if m.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
}
if r.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
if m.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
}
if r.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
if m.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
}
}

View File

@@ -133,7 +133,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
if err := client.Create(context.Background(), &request, fn); err != nil {
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -148,7 +148,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
name := args[0]
// check if the model exists on the server
_, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
@@ -208,7 +208,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
}
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(context.Background(), &request, fn); err != nil {
if err := client.Push(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -222,7 +222,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
return err
}
models, err := client.List(context.Background())
models, err := client.List(cmd.Context())
if err != nil {
return err
}
@@ -257,7 +257,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
for _, name := range args {
req := api.DeleteRequest{Name: name}
if err := client.Delete(context.Background(), &req); err != nil {
if err := client.Delete(cmd.Context(), &req); err != nil {
return err
}
fmt.Printf("deleted '%s'\n", name)
@@ -322,7 +322,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
}
req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(context.Background(), &req)
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
}
@@ -350,7 +350,7 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
}
req := api.CopyRequest{Source: args[0], Destination: args[1]}
if err := client.Copy(context.Background(), &req); err != nil {
if err := client.Copy(cmd.Context(), &req); err != nil {
return err
}
fmt.Printf("copied '%s' to '%s'\n", args[0], args[1])
@@ -404,7 +404,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(context.Background(), &request, fn); err != nil {
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -493,7 +493,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
opts.WordWrap = false
}
cancelCtx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
@@ -507,23 +507,22 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
var currentLineLength int
var wordBuffer string
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response
if opts.WordWrap {
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
if opts.WordWrap && termWidth >= 10 {
for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 {
if len(wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", wordBuffer, ch)
wordBuffer = ""
currentLineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
fmt.Printf("%s%c", wordBuffer, ch)
@@ -543,13 +542,26 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
}
}
} else {
fmt.Print(response.Response)
fmt.Printf("%s%s", wordBuffer, response.Response)
if len(wordBuffer) > 0 {
wordBuffer = ""
}
}
return nil
}
if err := client.Generate(cancelCtx, &request, fn); err != nil {
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
@@ -573,10 +585,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
latest.Summary()
}
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
return nil
}
@@ -705,11 +714,11 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
case MultilineSystem:
opts.System = prompt
prompt = ""
fmt.Println("Set system template.\n")
fmt.Println("Set system template.")
case MultilineTemplate:
opts.Template = prompt
prompt = ""
fmt.Println("Set model template.\n")
fmt.Println("Set model template.")
}
multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
@@ -784,9 +793,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
if found {
opts.System = prompt
if args[1] == "system" {
fmt.Println("Set system template.\n")
fmt.Println("Set system template.")
} else {
fmt.Println("Set prompt template.\n")
fmt.Println("Set prompt template.")
}
prompt = ""
} else {
@@ -800,7 +809,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
} else {
opts.System = line
fmt.Println("Set system template.\n")
fmt.Println("Set system template.")
}
default:
fmt.Printf("Unknown command '/set %s'. Type /? for help\n", args[1])
@@ -977,7 +986,7 @@ func initializeKeypair() error {
return nil
}
func startMacApp(client *api.Client) error {
func startMacApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
return err
@@ -1001,24 +1010,24 @@ func startMacApp(client *api.Client) error {
case <-timeout:
return errors.New("timed out waiting for server to start")
case <-tick:
if err := client.Heartbeat(context.Background()); err == nil {
if err := client.Heartbeat(ctx); err == nil {
return nil // server has started
}
}
}
}
func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if err := client.Heartbeat(context.Background()); err != nil {
if err := client.Heartbeat(cmd.Context()); err != nil {
if !strings.Contains(err.Error(), "connection refused") {
return err
}
if runtime.GOOS == "darwin" {
if err := startMacApp(client); err != nil {
if err := startMacApp(cmd.Context(), client); err != nil {
return fmt.Errorf("could not connect to ollama app, is it running?")
}
} else {
@@ -1028,8 +1037,29 @@ func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
return nil
}
func versionHandler(cmd *cobra.Command, _ []string) {
client, err := api.ClientFromEnvironment()
if err != nil {
return
}
serverVersion, err := client.Version(cmd.Context())
if err != nil {
fmt.Println("Warning: could not connect to a running Ollama instance")
}
if serverVersion != "" {
fmt.Printf("ollama version is %s\n", serverVersion)
}
if serverVersion != version.Version {
fmt.Printf("Warning: client version is %s\n", version.Version)
}
}
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
rootCmd := &cobra.Command{
Use: "ollama",
@@ -1039,10 +1069,17 @@ func NewCLI() *cobra.Command {
CompletionOptions: cobra.CompletionOptions{
DisableDefaultCmd: true,
},
Version: version.Version,
Run: func(cmd *cobra.Command, args []string) {
if version, _ := cmd.Flags().GetBool("version"); version {
versionHandler(cmd, args)
return
}
cmd.Print(cmd.UsageString())
},
}
cobra.EnableCommandSorting = false
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
createCmd := &cobra.Command{
Use: "create MODEL",

View File

@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
### Streaming responses
Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
Certain endpoints stream responses as JSON objects.
## Generate a completion
@@ -32,7 +32,7 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
POST /api/generate
```
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
@@ -47,7 +47,7 @@ Advanced parameters (optional):
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
### JSON mode
@@ -114,6 +114,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
#### Request (No streaming)
A response can be received in one reply when streaming is off.
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
@@ -144,9 +146,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
}
```
#### Request (Raw mode)
#### Request (Raw Mode)
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
```shell
curl http://localhost:11434/api/generate -d '{
@@ -164,6 +166,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "mistral",
"created_at": "2023-11-03T15:36:02.583064Z",
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
"context": [1, 2, 3],
"done": true,
"total_duration": 14648695333,
"load_duration": 3302671417,
@@ -249,7 +252,7 @@ curl http://localhost:11434/api/generate -d '{
"penalize_newline": true,
"stop": ["\n", "user:"],
"numa": false,
"num_ctx": 4,
"num_ctx": 1024,
"num_batch": 2,
"num_gqa": 1,
"num_gpu": 1,
@@ -264,7 +267,7 @@ curl http://localhost:11434/api/generate -d '{
"rope_frequency_base": 1.1,
"rope_frequency_scale": 0.8,
"num_thread": 8
}
}
}'
```
@@ -275,7 +278,6 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
@@ -288,6 +290,136 @@ curl http://localhost:11434/api/generate -d '{
}
```
## Send Chat Messages (coming in 0.1.14)
```shell
POST /api/chat
```
Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
Send a chat message with a streaming response.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
#### Request (With History)
Send a chat message with a conversation history.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
},
{
"role": "assistant",
"content": "due to rayleigh scattering."
},
{
"role": "user",
"content": "how is that different than mie scattering?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assisant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
## Create a Model
```shell

View File

@@ -43,7 +43,6 @@ Ollama supports a set of model architectures, with support for more coming soon:
- Llama & Mistral
- Falcon & RW
- GPT-NeoX
- BigCode
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
@@ -184,9 +183,6 @@ python convert.py <path to model directory>
# FalconForCausalLM
python convert-falcon-hf-to-gguf.py <path to model directory>
# GPTNeoXForCausalLM
python convert-gptneox-hf-to-gguf.py <path to model directory>
# GPTBigCodeForCausalLM
python convert-starcoder-hf-to-gguf.py <path to model directory>
```

83
docs/tutorials/fly-gpu.md Normal file
View File

@@ -0,0 +1,83 @@
# Running Ollama on Fly.io GPU Instances
Ollama runs with little to no configuration on [Fly.io GPU instances](https://fly.io/docs/gpus/gpu-quickstart/). If you don't have access to GPUs yet, you'll need to [apply for access](https://fly.io/gpu/) on the waitlist. Once you're accepted, you'll get an email with instructions on how to get started.
Create a new app with `fly apps create`:
```bash
fly apps create
```
Then create a `fly.toml` file in a new folder that looks like this:
```toml
app = "sparkling-violet-709"
primary_region = "ord"
vm.size = "a100-40gb" # see https://fly.io/docs/gpus/gpu-quickstart/ for more info
[build]
image = "ollama/ollama"
[http_service]
internal_port = 11434
force_https = false
auto_stop_machines = true
auto_start_machines = true
min_machines_running = 0
processes = ["app"]
[mounts]
source = "models"
destination = "/root/.ollama"
initial_size = "100gb"
```
Then create a [new private IPv6 address](https://fly.io/docs/reference/private-networking/#flycast-private-load-balancing) for your app:
```bash
fly ips allocate-v6 --private
```
Then deploy your app:
```bash
fly deploy
```
And finally you can access it interactively with a new Fly.io Machine:
```
fly machine run -e OLLAMA_HOST=http://your-app-name.flycast --shell ollama/ollama
```
```bash
$ ollama run openchat:7b-v3.5-fp16
>>> How do I bake chocolate chip cookies?
To bake chocolate chip cookies, follow these steps:
1. Preheat the oven to 375°F (190°C) and line a baking sheet with parchment paper or silicone baking mat.
2. In a large bowl, mix together 1 cup of unsalted butter (softened), 3/4 cup granulated sugar, and 3/4
cup packed brown sugar until light and fluffy.
3. Add 2 large eggs, one at a time, to the butter mixture, beating well after each addition. Stir in 1
teaspoon of pure vanilla extract.
4. In a separate bowl, whisk together 2 cups all-purpose flour, 1/2 teaspoon baking soda, and 1/2 teaspoon
salt. Gradually add the dry ingredients to the wet ingredients, stirring until just combined.
5. Fold in 2 cups of chocolate chips (or chunks) into the dough.
6. Drop rounded tablespoons of dough onto the prepared baking sheet, spacing them about 2 inches apart.
7. Bake for 10-12 minutes, or until the edges are golden brown. The centers should still be slightly soft.
8. Allow the cookies to cool on the baking sheet for a few minutes before transferring them to a wire rack
to cool completely.
Enjoy your homemade chocolate chip cookies!
```
When you set it up like this, it will automatically turn off when you're done using it. Then when you access it again, it will automatically turn back on. This is a great way to save money on GPU instances when you're not using them. If you want a persistent wake-on-use connection to your Ollama instance, you can set up a [connection to your Fly network using WireGuard](https://fly.io/docs/reference/private-networking/#discovering-apps-through-dns-on-a-wireguard-connection). Then you can access your Ollama instance at `http://your-app-name.flycast`.
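To reach the deployed instance from code rather than the CLI, the same `OLLAMA_HOST` variable can point the Go API client at the flycast address. A minimal sketch, assuming the example app name from the `fly.toml` above and that your machine can reach the app's private network (for example over a WireGuard connection):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// Point the client at the Fly.io app over its private flycast address.
	// "sparkling-violet-709" is the example app name from the fly.toml above.
	os.Setenv("OLLAMA_HOST", "http://sparkling-violet-709.flycast")

	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "openchat:7b-v3.5-fp16",
		Prompt: "How do I bake chocolate chip cookies?",
	}

	// Stream the response tokens to stdout as they arrive.
	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```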
And that's it!

View File

@@ -0,0 +1,46 @@
import json
import requests
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
model = "llama2" # TODO: update this for whatever model you wish to use
def chat(messages):
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
)
r.raise_for_status()
output = ""
for line in r.iter_lines():
body = json.loads(line)
if "error" in body:
raise Exception(body["error"])
if body.get("done") is False:
message = body.get("message", "")
content = message.get("content", "")
output += content
# the response streams one token at a time, print that as we receive it
print(content, end="", flush=True)
if body.get("done", False):
message["content"] = output
return message
def main():
messages = []
while True:
user_input = input("Enter a prompt: ")
print()
messages.append({"role": "user", "content": user_input})
message = chat(messages)
messages.append(message)
print("\n\n")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,24 @@
# Simple Chat Example
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of objects with a role and content specified. Then with each output and prompt, you add more of those role/content objects, which builds up the history.
## Review the Code
You can see in the **chat** function that actually calling the endpoint is done simply with:
```python
r = requests.post(
"http://0.0.0.0:11434/api/chat",
json={"model": model, "messages": messages, "stream": True},
)
```
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
The final JSON object doesn't provide the full content, so you will need to build the content yourself.
In the **main** function, we collect `user_input`, add it to the list of messages, and pass that list to the **chat** function. When the LLM is done responding, the output is added to the messages as another message.
## Next Steps
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.

View File

@@ -0,0 +1,77 @@
import * as readline from "readline";
const model = "llama2";
type Message = {
role: "assistant" | "user" | "system";
content: string;
}
const messages: Message[] = [{
role: "system",
content: "You are a helpful AI agent."
}]
const rl = readline.createInterface({
input: process.stdin,
output: process.stdout
})
async function chat(messages: Message[]): Promise<Message> {
const body = {
model: model,
messages: messages
}
const response = await fetch("http://localhost:11434/api/chat", {
method: "POST",
body: JSON.stringify(body)
})
const reader = response.body?.getReader()
if (!reader) {
throw new Error("Failed to read response body")
}
let content = ""
while (true) {
const { done, value } = await reader.read()
if (done) {
break;
}
const rawjson = new TextDecoder().decode(value);
const json = JSON.parse(rawjson)
if (json.done === false) {
process.stdout.write(json.message.content);
content += json.message.content
}
}
return { role: "assistant", content: content };
}
async function askQuestion(): Promise<void> {
return new Promise<void>((resolve) => {
rl.question("\n\nAsk a question: (press enter alone to quit)\n\n", async (user_input) => {
if (user_input.trim() === "") {
rl.close();
console.log("Thankyou. Goodbye.\n")
console.log("=======\nHere is the message history that was used in this conversation.\n=======\n")
messages.forEach(message => {
console.log(message)
})
resolve();
} else {
console.log();
messages.push({ role: "user", content: user_input });
messages.push(await chat(messages));
await askQuestion(); // Ask the next question
}
});
});
}
async function main() {
await askQuestion();
}
main();

View File

@@ -0,0 +1 @@
{ "dependencies": { "@types/node": "^20.10.4", "prompt-sync": "^4.2.0", "readline": "^1.3.0" } }

View File

@@ -0,0 +1,39 @@
# Simple Chat Example
The **chat** endpoint is one of two ways to generate text from an LLM with Ollama. At a high level you provide the endpoint an array of message objects with a role and content specified. Then with each output and prompt, you add more messages, which builds up the history.
## Run the Example
There are a few ways to run this, just like any TypeScript code:
1. Compile with `tsc` and then run it with `node client.js`.
2. Install `tsx` and run it with `tsx client.ts`.
3. Install `bun` and run it with `bun client.ts`.
## Review the Code
You can see in the **chat** function that calling the endpoint is done simply with:
```typescript
const body = {
model: model,
messages: messages
}
const response = await fetch("http://localhost:11434/api/chat", {
method: "POST",
body: JSON.stringify(body)
})
```
With the **generate** endpoint, you need to provide a `prompt`. But with **chat**, you provide `messages`. And the resulting stream of responses includes a `message` object with a `content` field.
The final JSON object doesn't provide the full content, so you will need to build the content yourself. In this example, **chat** takes the full array of messages and outputs the resulting message from this call of the chat endpoint.
In the **askQuestion** function, we collect `user_input`, add it to the messages array, and pass that array to the **chat** function. When the LLM is done responding, the output is added to the messages array as another message.
At the end, you will see a printout of all the messages.
## Next Steps
In this example, all generations are kept. You might want to experiment with summarizing everything older than 10 conversations to enable longer history with less context being used.

7
go.mod
View File

@@ -5,14 +5,15 @@ go 1.20
require (
github.com/emirpasic/gods v1.18.1
github.com/gin-gonic/gin v1.9.1
github.com/mattn/go-runewidth v0.0.14
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
golang.org/x/sync v0.3.0
)
require github.com/rivo/uniseg v0.2.0 // indirect
require (
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
)
require (
github.com/bytedance/sonic v1.9.1 // indirect

2
go.sum
View File

@@ -63,8 +63,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=

View File

@@ -7,9 +7,10 @@ import (
)
type GGML struct {
magic uint32
container
model
Size int64
}
const (
@@ -82,7 +83,7 @@ type model interface {
type container interface {
Name() string
Decode(io.Reader) (model, error)
Decode(*readSeekOffset) (model, error)
}
type containerGGML struct{}
@@ -91,7 +92,9 @@ func (c *containerGGML) Name() string {
return "ggml"
}
func (c *containerGGML) Decode(r io.Reader) (model, error) {
func (c *containerGGML) Decode(ro *readSeekOffset) (model, error) {
// file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return nil, nil
}
@@ -103,9 +106,9 @@ func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) (model, error) {
func (c *containerGGMF) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -114,6 +117,10 @@ func (c *containerGGMF) Decode(r io.Reader) (model, error) {
}
c.version = version
// remaining file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return nil, nil
}
@@ -125,9 +132,9 @@ func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) (model, error) {
func (c *containerGGJT) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
@@ -139,7 +146,11 @@ func (c *containerGGJT) Decode(r io.Reader) (model, error) {
// different model types may have different layouts for hyperparameters
var llama llamaModel
binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
binary.Read(ro, binary.LittleEndian, &llama.hyperparameters)
// remaining file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return &llama, nil
}
@@ -151,9 +162,9 @@ func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) (model, error) {
func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -162,6 +173,10 @@ func (c *containerLORA) Decode(r io.Reader) (model, error) {
}
c.version = version
// remaining file contents aren't decoded
ro.Seek(0, io.SeekEnd)
return nil, nil
}
@@ -180,33 +195,61 @@ const (
)
func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
var ggml GGML
binary.Read(r, binary.LittleEndian, &ggml.magic)
ro := readSeekOffset{ReadSeeker: r}
switch ggml.magic {
var magic uint32
if err := binary.Read(&ro, binary.LittleEndian, &magic); err != nil {
return nil, err
}
var c container
switch magic {
case FILE_MAGIC_GGML:
ggml.container = &containerGGML{}
c = &containerGGML{}
case FILE_MAGIC_GGMF:
ggml.container = &containerGGMF{}
c = &containerGGMF{}
case FILE_MAGIC_GGJT:
ggml.container = &containerGGJT{}
c = &containerGGJT{}
case FILE_MAGIC_GGLA:
ggml.container = &containerLORA{}
c = &containerLORA{}
case FILE_MAGIC_GGUF_LE:
ggml.container = &containerGGUF{bo: binary.LittleEndian}
c = &containerGGUF{bo: binary.LittleEndian}
case FILE_MAGIC_GGUF_BE:
ggml.container = &containerGGUF{bo: binary.BigEndian}
c = &containerGGUF{bo: binary.BigEndian}
default:
return nil, errors.New("invalid file magic")
}
model, err := ggml.Decode(r)
model, err := c.Decode(&ro)
if err != nil {
return nil, err
}
ggml.model = model
// final model type
return &ggml, nil
return &GGML{
container: c,
model: model,
Size: ro.offset,
}, nil
}
type readSeekOffset struct {
io.ReadSeeker
offset int64
}
func (rso *readSeekOffset) Seek(offset int64, whence int) (int64, error) {
offset, err := rso.ReadSeeker.Seek(offset, whence)
if err != nil {
return 0, err
}
rso.offset = offset
return offset, nil
}
func (rso *readSeekOffset) Read(p []byte) (int, error) {
n, err := rso.ReadSeeker.Read(p)
rso.offset += int64(n)
return n, err
}

View File

@@ -23,26 +23,24 @@ type containerGGUF struct {
NumTensor uint64
NumKV uint64
}
parameters uint64
}
func (c *containerGGUF) Name() string {
return "gguf"
}
func (c *containerGGUF) Decode(r io.Reader) (model, error) {
binary.Read(r, c.bo, &c.Version)
func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) {
binary.Read(rso, c.bo, &c.Version)
switch c.Version {
case 1:
binary.Read(r, c.bo, &c.V1)
binary.Read(rso, c.bo, &c.V1)
default:
binary.Read(r, c.bo, &c.V2)
binary.Read(rso, c.bo, &c.V2)
}
model := newGGUFModel(c)
if err := model.Decode(r); err != nil {
if err := model.Decode(rso); err != nil {
return nil, err
}
@@ -67,9 +65,23 @@ const (
type kv map[string]any
type tensor struct {
name string
kind uint32
offset uint64
size uint64
// shape is the number of elements in each dimension
shape [4]uint64
}
type ggufModel struct {
*containerGGUF
kv
tensors []tensor
parameters uint64
}
func newGGUFModel(container *containerGGUF) *ggufModel {
@@ -96,8 +108,7 @@ func (llm *ggufModel) NumKV() uint64 {
}
func (llm *ggufModel) ModelFamily() string {
t, ok := llm.kv["general.architecture"].(string)
if ok {
if t, ok := llm.kv["general.architecture"].(string); ok {
return t
}
@@ -134,57 +145,56 @@ func (llm *ggufModel) ModelType() string {
}
func (llm *ggufModel) FileType() string {
t, ok := llm.kv["general.file_type"].(uint32)
if ok {
if t, ok := llm.kv["general.file_type"].(uint32); ok {
return fileType(t)
}
return "unknown"
}
func (llm *ggufModel) Decode(r io.Reader) error {
func (llm *ggufModel) Decode(rso *readSeekOffset) error {
// decode key-values
for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := llm.readString(r)
k, err := llm.readString(rso)
if err != nil {
return err
}
vtype := llm.readU32(r)
vtype := llm.readU32(rso)
var v any
switch vtype {
case ggufTypeUint8:
v = llm.readU8(r)
v = llm.readU8(rso)
case ggufTypeInt8:
v = llm.readI8(r)
v = llm.readI8(rso)
case ggufTypeUint16:
v = llm.readU16(r)
v = llm.readU16(rso)
case ggufTypeInt16:
v = llm.readI16(r)
v = llm.readI16(rso)
case ggufTypeUint32:
v = llm.readU32(r)
v = llm.readU32(rso)
case ggufTypeInt32:
v = llm.readI32(r)
v = llm.readI32(rso)
case ggufTypeUint64:
v = llm.readU64(r)
v = llm.readU64(rso)
case ggufTypeInt64:
v = llm.readI64(r)
v = llm.readI64(rso)
case ggufTypeFloat32:
v = llm.readF32(r)
v = llm.readF32(rso)
case ggufTypeFloat64:
v = llm.readF64(r)
v = llm.readF64(rso)
case ggufTypeBool:
v = llm.readBool(r)
v = llm.readBool(rso)
case ggufTypeString:
s, err := llm.readString(r)
s, err := llm.readString(rso)
if err != nil {
return err
}
v = s
case ggufTypeArray:
a, err := llm.readArray(r)
a, err := llm.readArray(rso)
if err != nil {
return err
}
@@ -199,21 +209,85 @@ func (llm *ggufModel) Decode(r io.Reader) error {
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
if _, err := llm.readString(r); err != nil {
name, err := llm.readString(rso)
if err != nil {
return err
}
dimensions := llm.readU32(r)
// dims is the number of dimensions in the tensor
dims := llm.readU32(rso)
var elements uint64 = 1
for i := 0; uint32(i) < dimensions; i++ {
elements *= llm.readU64(r)
shape := [4]uint64{1, 1, 1, 1}
for i := 0; uint32(i) < dims; i++ {
shape[i] = llm.readU64(rso)
}
llm.readU32(r) // type
llm.readU64(r) // offset
kind := llm.readU32(rso)
offset := llm.readU64(rso)
llm.parameters += elements
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name,
kind: kind,
offset: offset,
size: size,
shape: shape,
})
llm.parameters += parameters
}
alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
for _, tensor := range llm.tensors {
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent)
}
return nil

View File

@@ -59,6 +59,7 @@ ws ::= ([ \t\n] ws)?
var llamaCppEmbed embed.FS
type ModelRunner struct {
Type string // "gguf" or "ggml"
Path string // path to the model runner executable
Accelerated bool
}
@@ -72,25 +73,25 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
runners = []ModelRunner{{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
} else {
runners = []ModelRunner{{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
runners = []ModelRunner{{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
}
case "linux":
runners = []ModelRunner{
{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
}
case "windows":
// TODO: select windows GPU runner here when available
runners = []ModelRunner{
{Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
{Type: runnerType, Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
}
default:
log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
runners = []ModelRunner{
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
{Type: runnerType, Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
}
}
@@ -148,6 +149,7 @@ func chooseRunners(workDir, runnerType string) []ModelRunner {
for _, r := range runners {
// clean the ModelRunner paths so that they match the OS we are running on
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
Type: r.Type,
Path: filepath.Clean(path.Join(workDir, r.Path)),
Accelerated: r.Accelerated,
})
@@ -325,7 +327,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
return os.Stderr.Write(b)
}
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
fileInfo, err := os.Stat(model)
if err != nil {
return nil, err
@@ -365,6 +367,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
params = append(params, "--lora", adapters[0])
}
if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}
if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
}
@@ -397,11 +404,17 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
}
port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
params := append(params, "--port", strconv.Itoa(port))
if runner.Type == "gguf" {
params = append(params, "--parallel", "2")
}
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(
ctx,
runner.Path,
append(params, "--port", strconv.Itoa(port))...,
params...,
)
var libraryPaths []string
@@ -531,21 +544,28 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
}
type PredictOpts struct {
Prompt string
Format string
CheckpointStart time.Time
CheckpointLoaded time.Time
}
// Remove leading spaces from prevConvo if present
prevConvo = strings.TrimPrefix(prevConvo, " ")
var nextContext strings.Builder
nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt)
type PredictResult struct {
CreatedAt time.Time
TotalDuration time.Duration
LoadDuration time.Duration
Content string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
request := map[string]any{
"prompt": nextContext.String(),
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
@@ -567,7 +587,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop,
}
if format == "json" {
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
@@ -617,34 +637,35 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
continue
}
if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
var p prediction
if err := json.Unmarshal(evt, &p); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
}
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line)
}
if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content})
nextContext.WriteString(p.Content)
}
var p prediction
if err := json.Unmarshal(evt, &p); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
}
if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String())
if err != nil {
return fmt.Errorf("encoding context: %v", err)
}
if p.Content != "" {
fn(PredictResult{
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
}
fn(api.GenerateResponse{
Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
if p.Stop {
fn(PredictResult{
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
return nil
}
Done: true,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}
}

View File

@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Predict(context.Context, PredictOpts, func(PredictResult)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)
@@ -23,7 +23,7 @@ type LLM interface {
Ping(context.Context) error
}
func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
@@ -82,9 +82,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
return newLlama(model, adapters, projectors, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
return newLlama(model, adapters, projectors, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
}

View File

@@ -37,10 +37,13 @@ func Parse(reader io.Reader) ([]Command, error) {
switch string(bytes.ToUpper(fields[0])) {
case "FROM":
command.Name = "model"
command.Args = string(fields[1])
command.Args = string(bytes.TrimSpace(fields[1]))
// copy command for validation
modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "ADAPTER":
case "ADAPTER":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(bytes.TrimSpace(fields[1]))
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":
@@ -50,7 +53,7 @@ func Parse(reader io.Reader) ([]Command, error) {
}
command.Name = string(fields[0])
command.Args = string(fields[1])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
default:

View File

@@ -191,6 +191,15 @@ func (i *Instance) Readline() (string, error) {
buf.ClearScreen()
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:
if err := UnsetRawMode(fd, termios); err != nil {
return "", err
}
syscall.Kill(0, syscall.SIGSTOP)
// on resume...
return "", nil
case CharEnter:
output := buf.String()
if output != "" {

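The new CharCtrlZ case drops out of raw mode before the process stops itself with SIGSTOP, so the parent shell gets a usable terminal while the program is suspended. A standalone sketch of the same pattern, substituting golang.org/x/term for the package's own UnsetRawMode helper (an assumption for illustration only):

```go
//go:build unix

package main

import (
	"fmt"
	"os"
	"syscall"

	"golang.org/x/term"
)

func main() {
	fd := int(os.Stdin.Fd())
	oldState, err := term.MakeRaw(fd)
	if err != nil {
		panic(err)
	}
	defer term.Restore(fd, oldState)

	buf := make([]byte, 1)
	for {
		if _, err := os.Stdin.Read(buf); err != nil {
			return
		}
		switch buf[0] {
		case 0x1a: // Ctrl+Z: restore the terminal before suspending ourselves
			term.Restore(fd, oldState)
			syscall.Kill(0, syscall.SIGSTOP)
			// on resume, re-enter raw mode
			oldState, _ = term.MakeRaw(fd)
		case 0x04, '\r': // Ctrl+D or Enter: quit
			fmt.Print("\r\n")
			return
		}
	}
}
```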
View File

@@ -217,7 +217,7 @@ fi
if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\.[0-9]*")" ]; then
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
amzn) install_cuda_driver_yum 'fedora' '35' ;;
@@ -230,7 +230,8 @@ fi
if ! lsmod | grep -q nvidia; then
KERNEL_RELEASE="$(uname -r)"
case $OS_NAME in
centos|rhel|rocky|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
rocky) $SUDO $PACKAGE_MANAGER -y install kernel-devel kernel-headers ;;
centos|rhel|amzn) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE kernel-headers-$KERNEL_RELEASE ;;
fedora) $SUDO $PACKAGE_MANAGER -y install kernel-devel-$KERNEL_RELEASE ;;
debian|ubuntu) $SUDO apt-get -y install linux-headers-$KERNEL_RELEASE ;;
*) exit ;;

View File

@@ -35,80 +35,157 @@ type RegistryOptions struct {
}
type Model struct {
Name string `json:"name"`
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
Name string `json:"name"`
Config ConfigV2
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
ProjectorPaths []string
Template string
System string
License []string
Digest string
Options map[string]interface{}
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
t := m.Template
if request.Template != "" {
t = request.Template
}
type PromptVars struct {
System string
Prompt string
Response string
First bool
}
tmpl, err := template.New("").Parse(t)
func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
if err != nil {
return "", err
}
var vars struct {
First bool
System string
Prompt string
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
vars := map[string]any{
"System": p.System,
"Prompt": p.Prompt,
"Response": p.Response,
"First": p.First,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
return prompt.String(), nil
}
return sb.String(), nil
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
// build the prompt from the list of messages
var prompt strings.Builder
currentVars := PromptVars{
First: true,
}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
if currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.Prompt = msg.Content
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
return prompt.String(), nil
}
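
The PromptVars/ChatPrompt pair above folds a message history into a single prompt: each set of system/user variables is rendered through the model template, and assistant turns are appended as the Response. A self-contained sketch of that flow, using text/template with missingkey=zero and the [INST] template from the tests later in this diff (the helper names here are local stand-ins, not the server package's API):

```go
package main

import (
	"fmt"
	"strings"
	"text/template"
)

type promptVars struct {
	System, Prompt, Response string
	First                    bool
}

// render mirrors Model.Prompt: execute the template with missingkey=zero,
// then append the assistant response verbatim.
func render(tmpl string, v promptVars) (string, error) {
	t, err := template.New("").Option("missingkey=zero").Parse(tmpl)
	if err != nil {
		return "", err
	}
	var sb strings.Builder
	if err := t.Execute(&sb, map[string]any{
		"System": v.System, "Prompt": v.Prompt, "Response": v.Response, "First": v.First,
	}); err != nil {
		return "", err
	}
	return sb.String() + v.Response, nil
}

func main() {
	tmpl := "[INST] {{ .System }} {{ .Prompt }} [/INST]"
	turns := []promptVars{
		{System: "You are a Wizard.", Prompt: "What are the potion ingredients?", Response: "sugar", First: true},
		{Prompt: "Anything else?"},
	}

	var prompt strings.Builder
	for _, v := range turns {
		p, err := render(tmpl, v)
		if err != nil {
			panic(err)
		}
		prompt.WriteString(p)
	}
	// prints a prompt of the same shape as the ChatPrompt test expectations:
	// [INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] ... [/INST]
	fmt.Println(prompt.String())
}
```

As in ChatPrompt above, a repeated system or user role flushes the pending variables into the prompt before a new set is started.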
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config Layer `json:"config"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
}
type LayerReader struct {
Layer
io.Reader
}
type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
RootFS RootFS `json:"rootfs"`
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`
RootFS RootFS `json:"rootfs"`
}
func (c *ConfigV2) SetModelFormat(format string) {
if c.ModelFormat == "" {
c.ModelFormat = format
}
}
func (c *ConfigV2) SetModelFamily(families ...string) {
for _, family := range families {
if c.ModelFamily == "" {
c.ModelFamily = family
}
if !slices.Contains(c.ModelFamilies, family) {
c.ModelFamilies = append(c.ModelFamilies, family)
}
}
}
func (c *ConfigV2) SetModelType(modelType string) {
if c.ModelType == "" {
c.ModelType = modelType
}
}
func (c *ConfigV2) SetFileType(fileType string) {
if c.FileType == "" {
c.FileType = fileType
}
}
type RootFS struct {
@@ -167,6 +244,21 @@ func GetModel(name string) (*Model, error) {
License: []string{},
}
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
@@ -183,6 +275,8 @@ func GetModel(name string) (*Model, error) {
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
@@ -256,11 +350,14 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
deleteMap := make(map[string]struct{})
var layers []*LayerReader
var layers Layers
params := make(map[string][]string)
fromParams := make(map[string]any)
@@ -317,10 +414,10 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return err
}
config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = fromConfig.ModelFamily
config.ModelType = fromConfig.ModelType
config.FileType = fromConfig.FileType
config.SetModelFormat(fromConfig.ModelFormat)
config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
config.SetModelType(fromConfig.ModelType)
config.SetFileType(fromConfig.FileType)
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
@@ -341,13 +438,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
}
layer, err := GetLayerWithBufferFromLayer(layer)
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil {
return err
}
layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
layers.Add(layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
@@ -355,25 +451,38 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil {
return err
var offset int64
for {
fn(api.ProgressResponse{Status: "creating model layer"})
bin.Seek(offset, io.SeekStart)
ggml, err := llm.DecodeGGML(bin)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
config.SetModelFormat(ggml.Name())
config.SetModelFamily(ggml.ModelFamily())
config.SetModelType(ggml.ModelType())
config.SetFileType(ggml.FileType())
mediatype := mediatype
if ggml.ModelFamily() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
sr := io.NewSectionReader(bin, offset, ggml.Size)
layer, err := NewLayer(sr, mediatype)
if err != nil {
return err
}
layers.Add(layer)
offset += ggml.Size
}
config.ModelFormat = ggml.Name()
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
bin.Seek(0, io.SeekStart)
layer, err := CreateLayer(bin)
if err != nil {
return err
}
layer.MediaType = mediatype
layers = append(layers, layer)
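
The rewritten model branch above now decodes GGML blobs in a loop, seeking to the start of each section and wrapping it in an io.NewSectionReader so that a file carrying, say, base weights plus a CLIP projector yields one layer per section. A toy illustration of that seek/decode/slice loop, using a made-up length-prefixed container instead of GGML:

```go
package main

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
)

// toy container: each section is an 8-byte big-endian length followed by its payload.
func writeSection(w io.Writer, payload []byte) {
	binary.Write(w, binary.BigEndian, uint64(len(payload)))
	w.Write(payload)
}

// decodeHeader reads one section header at the current position and reports
// the total size of the section (header + payload), playing the role of ggml.Size.
func decodeHeader(r io.Reader) (int64, error) {
	var n uint64
	if err := binary.Read(r, binary.BigEndian, &n); err != nil {
		return 0, err
	}
	return int64(8 + n), nil
}

func main() {
	var buf bytes.Buffer
	writeSection(&buf, []byte("base model weights"))
	writeSection(&buf, []byte("clip projector"))

	bin := bytes.NewReader(buf.Bytes())

	var offset int64
	for {
		if _, err := bin.Seek(offset, io.SeekStart); err != nil {
			panic(err)
		}
		size, err := decodeHeader(bin)
		if errors.Is(err, io.EOF) {
			break // no more sections in the file
		} else if err != nil {
			panic(err)
		}

		// hand each section off as its own "layer", like NewLayer(sr, mediatype)
		sr := io.NewSectionReader(bin, offset, size)
		data, _ := io.ReadAll(io.NewSectionReader(sr, 8, size-8)) // skip the toy header
		fmt.Printf("layer at offset %d, size %d: %q\n", offset, size, data)

		offset += size
	}
}
```

In the real code the section size comes from the decoded GGML, and sections whose model family is "clip" are tagged with the projector media type instead of the model one.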
case "adapter":
if strings.HasPrefix(c.Args, "@") {
blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
@@ -383,7 +492,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
c.Args = blobPath
}
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(modelFileDir, c.Args))
if err != nil {
@@ -391,41 +500,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
defer bin.Close()
layer, err := CreateLayer(bin)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
layers.Add(layer)
case "license":
fn(api.ProgressResponse{Status: "creating license layer"})
layer, err := CreateLayer(strings.NewReader(c.Args))
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
layers.Add(layer)
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove duplicate layers
layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args))
bin := strings.NewReader(c.Args)
layer, err := NewLayer(bin, mediatype)
if err != nil {
return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
layers.Replace(layer)
default:
params[c.Name] = append(params[c.Name], c.Args)
}
@@ -457,40 +557,51 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return err
}
layer.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, layer)
layers.Replace(layer)
}
digests, err := getLayerDigests(layers)
digests := make([]string, len(layers.items))
for i, layer := range layers.items {
digests[i] = layer.Digest
}
config.RootFS.DiffIDs = digests
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(config); err != nil {
return err
}
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}
configLayer, err := createConfigLayer(config, digests)
if err != nil {
return err
}
layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil {
return err
}
for _, layer := range append(layers.items, configLayer) {
committed, err := layer.Commit()
if err != nil {
return err
}
status := "writing layer"
if !committed {
status = "using already created layer"
}
fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := CreateManifest(name, configLayer, contentLayers); err != nil {
if err := WriteManifest(name, configLayer, layers.items); err != nil {
return err
}
@@ -504,119 +615,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return nil
}
func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerReader {
return slices.DeleteFunc(layers, func(layer *LayerReader) bool {
return layer.MediaType == mediaType
})
}
func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error {
// Write each of the layers to disk
for _, layer := range layers {
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
_, err = os.Stat(fp)
if os.IsNotExist(err) || force {
fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)})
out, err := os.Create(fp)
if err != nil {
log.Printf("couldn't create %s", fp)
return err
}
defer out.Close()
if _, err = io.Copy(out, layer.Reader); err != nil {
return err
}
} else {
fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)})
}
}
return nil
}
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp := ParseModelPath(name)
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: Layer{
MediaType: cfg.MediaType,
Size: cfg.Size,
Digest: cfg.Digest,
},
Layers: layers,
}
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
return os.WriteFile(fp, manifestJSON, 0o644)
}
func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
file, err := os.Open(fp)
if err != nil {
return nil, fmt.Errorf("could not open blob: %w", err)
}
defer file.Close()
newLayer, err := CreateLayer(file)
if err != nil {
return nil, err
}
newLayer.MediaType = layer.MediaType
return newLayer, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) {
var digests []string
for _, l := range layers {
if l.Digest == "" {
return nil, fmt.Errorf("layer is missing a digest")
}
digests = append(digests, l.Digest)
}
return digests, nil
}
// CreateLayer creates a Layer object from a given file
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
digest, size := GetSHA256Digest(f)
f.Seek(0, io.SeekStart)
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.docker.image.rootfs.diff.tar",
Digest: digest,
Size: size,
},
Reader: f,
}
return layer, nil
}
func CopyModel(src, dest string) error {
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath()
@@ -884,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, &manifest.Config)
layers = append(layers, manifest.Config)
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@@ -955,7 +953,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, &manifest.Config)
layers = append(layers, manifest.Config)
for _, layer := range layers {
if err := downloadBlob(
@@ -1043,30 +1041,6 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
return m, err
}
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
config.RootFS = RootFS{
Type: "layers",
DiffIDs: layers,
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
digest, size := GetSHA256Digest(bytes.NewBuffer(configJSON))
layer := &LayerReader{
Layer: Layer{
MediaType: "application/vnd.docker.container.image.v1+json",
Digest: digest,
Size: size,
},
Reader: bytes.NewBuffer(configJSON),
}
return layer, nil
}
// GetSHA256Digest returns the SHA256 hash of a given buffer, along with its size
func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New()

View File

@@ -1,23 +1,98 @@
package server
import (
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestModelPrompt(t *testing.T) {
var m Model
req := api.GenerateRequest{
Template: "a{{ .Prompt }}b",
Prompt: "<h1>",
func TestChat(t *testing.T) {
tests := []struct {
name string
template string
msgs []api.Message
want string
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "sugar",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "assistant",
Content: "everything nice",
},
},
want: "[INST] [/INST]everything nice",
},
{
name: "Invalid Role",
msgs: []api.Message{
{
Role: "not-a-role",
Content: "howdy",
},
},
wantErr: "invalid role: not-a-role",
},
}
s, err := m.Prompt(req)
if err != nil {
t.Fatal(err)
}
want := "a<h1>b"
if s != want {
t.Errorf("got %q, want %q", s, want)
for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
got, err := m.ChatPrompt(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
}
})
}
}

server/layers.go Normal file
View File

@@ -0,0 +1,109 @@
package server
import (
"crypto/sha256"
"fmt"
"io"
"os"
"runtime"
"strings"
"golang.org/x/exp/slices"
)
type Layers struct {
items []*Layer
}
func (ls *Layers) Add(layer *Layer) {
if layer.Size > 0 {
ls.items = append(ls.items, layer)
}
}
func (ls *Layers) Replace(layer *Layer) {
if layer.Size > 0 {
mediatype := layer.MediaType
layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
return l.MediaType == mediatype
})
ls.items = append(layers, layer)
}
}
type Layer struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
Size int64 `json:"size"`
From string `json:"from,omitempty"`
tempFileName string
}
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
return nil, err
}
delimiter := ":"
if runtime.GOOS == "windows" {
delimiter = "-"
}
pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
temp, err := os.CreateTemp(blobs, pattern)
if err != nil {
return nil, err
}
defer temp.Close()
sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil {
return nil, err
}
return &Layer{
MediaType: mediatype,
Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
Size: n,
tempFileName: temp.Name(),
}, nil
}
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
}
fi, err := os.Stat(blob)
if err != nil {
return nil, err
}
return &Layer{
MediaType: mediatype,
Digest: digest,
Size: fi.Size(),
From: from,
}, nil
}
func (l *Layer) Commit() (bool, error) {
// always remove temp
defer os.Remove(l.tempFileName)
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return false, err
}
if _, err := os.Stat(blob); err != nil {
return true, os.Rename(l.tempFileName, blob)
}
return false, nil
}
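
NewLayer spools the payload into a *-partial temp file while hashing it, and Commit renames that file into the blob store only when the digest is not already present. A minimal standalone sketch of the write-temp/hash/rename pattern, with a throwaway directory standing in for GetBlobsPath:

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"
)

type layer struct {
	Digest, MediaType string
	Size              int64
	tempFileName      string
}

func newLayer(blobs string, r io.Reader, mediatype string) (*layer, error) {
	temp, err := os.CreateTemp(blobs, "sha256-*-partial")
	if err != nil {
		return nil, err
	}
	defer temp.Close()

	h := sha256.New()
	n, err := io.Copy(io.MultiWriter(temp, h), r) // hash while spooling to disk
	if err != nil {
		return nil, err
	}
	return &layer{
		MediaType:    mediatype,
		Digest:       fmt.Sprintf("sha256:%x", h.Sum(nil)),
		Size:         n,
		tempFileName: temp.Name(),
	}, nil
}

// commit moves the temp file into place unless a blob with this digest already exists.
func (l *layer) commit(blobs string) (bool, error) {
	defer os.Remove(l.tempFileName)
	dest := filepath.Join(blobs, strings.ReplaceAll(l.Digest, ":", "-"))
	if _, err := os.Stat(dest); err != nil {
		return true, os.Rename(l.tempFileName, dest)
	}
	return false, nil // already created
}

func main() {
	blobs, _ := os.MkdirTemp("", "blobs")
	l, err := newLayer(blobs, strings.NewReader("{{ .Prompt }}"), "application/vnd.ollama.image.template")
	if err != nil {
		panic(err)
	}
	written, err := l.commit(blobs)
	if err != nil {
		panic(err)
	}
	fmt.Println(l.Digest, l.Size, "written:", written)
}
```

The rename-if-absent step is what lets repeated CreateModel runs report "using already created layer" instead of rewriting blobs.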

server/manifests.go Normal file
View File

@@ -0,0 +1,34 @@
package server
import (
"bytes"
"encoding/json"
"os"
"path/filepath"
)
func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifest := ManifestV2{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config,
Layers: layers,
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
modelpath := ParseModelPath(name)
manifestPath, err := modelpath.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0644)
}
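
For reference, WriteManifest serializes a document shaped like the one produced below; the digests, sizes, and the model media type are invented placeholders, but the field layout follows the ManifestV2 and Layer structs earlier in this diff.

```go
package main

import (
	"encoding/json"
	"os"
)

// Local copies of the manifest shapes from this diff.
type layer struct {
	MediaType string `json:"mediaType"`
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
}

type manifestV2 struct {
	SchemaVersion int      `json:"schemaVersion"`
	MediaType     string   `json:"mediaType"`
	Config        *layer   `json:"config"`
	Layers        []*layer `json:"layers"`
}

func main() {
	m := manifestV2{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Config:        &layer{MediaType: "application/vnd.docker.container.image.v1+json", Digest: "sha256:aaaa...", Size: 456},
		Layers: []*layer{
			// model media type assumed for illustration; not shown in this diff
			{MediaType: "application/vnd.ollama.image.model", Digest: "sha256:bbbb...", Size: 3825819519},
			{MediaType: "application/vnd.ollama.image.template", Digest: "sha256:cccc...", Size: 33},
			{MediaType: "application/vnd.ollama.image.params", Digest: "sha256:dddd...", Size: 84},
		},
	}
	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", "  ")
	enc.Encode(m)
}
```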

View File

@@ -2,7 +2,6 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
@@ -60,17 +59,26 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded; it is up to the caller to lock loaded.mu before calling this function
func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
model, err := GetModel(modelName)
if err != nil {
return nil, err
}
workDir := c.GetString("workDir")
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
return err
return nil, err
}
if err := opts.FromMap(reqOpts); err != nil {
return err
return nil, err
}
ctx := c.Request.Context()
// check if the loaded model is still running in a subprocess, in case something unexpected happened
if loaded.runner != nil {
if err := loaded.runner.Ping(ctx); err != nil {
@@ -97,7 +105,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.Options = nil
}
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
@@ -106,7 +114,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
return nil, err
}
loaded.Model = model
@@ -140,7 +148,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
}
loaded.expireTimer.Reset(sessionDuration)
return nil
return model, nil
}
func GenerateHandler(c *gin.Context) {
@@ -173,88 +181,148 @@ func GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
workDir := c.GetString("workDir")
// TODO: set this duration from the request if specified
sessionDuration := defaultSessionDuration
if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true})
return
}
checkpointLoaded := time.Now()
prompt := req.Prompt
if !req.Raw {
prompt, err = model.Prompt(req)
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
var rebuild strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
p, err := model.Prompt(PromptVars{
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rebuild.WriteString(p)
prompt = rebuild.String()
}
ch := make(chan any)
var generated strings.Builder
go func() {
defer close(ch)
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
return
}
fn := func(r api.GenerateResponse) {
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
if r.Done {
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if req.Raw {
// in raw mode the client must manage history on their own
r.Context = nil
resp := api.GenerateResponse{
Model: req.Model,
CreatedAt: r.CreatedAt,
Done: r.Done,
Response: r.Content,
Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
ch <- r
if r.Done && !req.Raw {
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = embd
}
ch <- resp
}
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
var response api.GenerateResponse
generated := ""
// Accumulate responses into the final response
var final api.GenerateResponse
var sb strings.Builder
for resp := range ch {
if r, ok := resp.(api.GenerateResponse); ok {
generated += r.Response
response = r
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
switch r := resp.(type) {
case api.GenerateResponse:
sb.WriteString(r.Response)
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
response.Response = generated
c.JSON(http.StatusOK, response)
final.Response = sb.String()
c.JSON(http.StatusOK, final)
return
}
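
Both GenerateHandler and the new ChatHandler now share this shape: a goroutine pushes typed values onto ch, and the caller either streams them out or, when stream=false, folds them into one final response. A condensed sketch of that fan-in/accumulate pattern:

```go
package main

import (
	"fmt"
	"strings"
)

type chunk struct {
	Response string
	Done     bool
}

func main() {
	ch := make(chan any)

	// producer: emits partial chunks, then a Done marker (or an error value).
	go func() {
		defer close(ch)
		for _, w := range []string{"Hello", ", ", "world"} {
			ch <- chunk{Response: w}
		}
		ch <- chunk{Done: true}
	}()

	// non-streaming consumer: accumulate partials into one final response.
	var final chunk
	var sb strings.Builder
	for resp := range ch {
		switch r := resp.(type) {
		case chunk:
			sb.WriteString(r.Response)
			final = r
		case error:
			fmt.Println("error:", r)
			return
		default:
			fmt.Println("unexpected value in response channel")
			return
		}
	}
	final.Response = sb.String()
	fmt.Printf("%+v\n", final) // {Response:Hello, world Done:true}
}
```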
@@ -281,15 +349,18 @@ func EmbeddingHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
sessionDuration := defaultSessionDuration
_, err = load(c, req.Model, req.Options, sessionDuration)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
workDir := c.GetString("workDir")
if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
@@ -676,37 +747,18 @@ func HeadBlobHandler(c *gin.Context) {
}
func CreateBlobHandler(c *gin.Context) {
targetPath, err := GetBlobsPath(c.Param("digest"))
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
hash := sha256.New()
temp, err := os.CreateTemp(filepath.Dir(targetPath), c.Param("digest")+"-")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if layer.Digest != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
if _, err := layer.Commit(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -767,6 +819,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
@@ -782,6 +835,9 @@ func Serve(ln net.Listener, allowOrigins []string) error {
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
@@ -804,7 +860,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if runtime.GOOS == "linux" {
// check compatibility to log warnings
if _, err := llm.CheckVRAM(); err != nil {
log.Printf(err.Error())
log.Print(err.Error())
}
}
@@ -860,3 +916,136 @@ func streamResponse(c *gin.Context, ch chan any) {
return true
})
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
return
}
checkpointLoaded := time.Now()
prompt, err := model.ChatPrompt(req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: r.CreatedAt,
Done: r.Done,
Metrics: api.Metrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if !r.Done {
resp.Message = &api.Message{Role: "assistant", Content: r.Content}
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.ChatResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.ChatResponse:
if r.Message != nil {
sb.WriteString(r.Message.Content)
}
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
final.Message = &api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
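
With the /api/chat route registered above, a client can drive the new handler directly. The sketch below is illustrative rather than an official client: it assumes the server's default listen address (127.0.0.1:11434), lower-cased JSON field names matching the Go fields shown in this diff, and newline-delimited JSON objects on the streamed response.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatRequest struct {
	Model    string    `json:"model"`
	Messages []message `json:"messages"`
	Stream   bool      `json:"stream"`
}

type chatResponse struct {
	Message *message `json:"message"`
	Done    bool     `json:"done"`
}

func main() {
	body, _ := json.Marshal(chatRequest{
		Model: "llama2", // any locally pulled model
		Messages: []message{
			{Role: "system", Content: "You are a Wizard."},
			{Role: "user", Content: "What are the potion ingredients?"},
		},
		Stream: true,
	})

	resp, err := http.Post("http://127.0.0.1:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// assumed framing: one JSON object per line while streaming
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var r chatResponse
		if err := json.Unmarshal(scanner.Bytes(), &r); err != nil {
			panic(err)
		}
		if r.Message != nil {
			fmt.Print(r.Message.Content)
		}
		if r.Done {
			fmt.Println()
		}
	}
}
```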