Compare commits

...

11 Commits

Author SHA1 Message Date
Bruce MacDonald
a5bc4b7c17 Update images_test.go 2024-12-03 15:03:08 -08:00
Bruce MacDonald
1be080403d server: feedback before failing push on uppercase
When a username or model name is uppercase the registry will reject the
push. This is done for file-system compatibility. If we rely on the registry
error on push the message returned is 'file not found', which does not
convey why the push actually failed.
2024-12-03 14:40:23 -08:00
Tigran
55c3efa900 docs: remove extra quote in modelfile.md (#7908) 2024-12-02 09:28:56 -08:00
David Mayboroda
1aedffad93 readme: add minima to community integrations (#7906) 2024-12-02 01:14:47 -08:00
Jeffrey Morgan
ff6c2d6dc8 cmd: don't rely on reading repo file for test (#7898) 2024-11-30 14:12:53 -08:00
Jeffrey Morgan
d543b282a7 server: add warning message for deprecated context field (#7878) 2024-11-30 14:05:50 -08:00
Parth Sareen
5f8051180e Enable index tracking for tools - openai api support (#7888) 2024-11-29 20:00:09 -08:00
Jeffrey Morgan
39e29ae5dd llama: fix typo and formatting in readme (#7876) 2024-11-28 17:27:11 -08:00
TheCookingSenpai
30a9f063c9 readme: add SpaceLlama, YouLama, and DualMind to community integrations (#7216) 2024-11-28 15:16:27 -08:00
Parth Sareen
ce7455a8e1 api: enable tool streaming (#7836) 2024-11-27 13:40:57 -08:00
ItzCrazyKns
e3936d4fb3 Support Multiple LoRa Adapters (#7667)
Closes #7627
2024-11-27 11:00:04 -08:00
15 changed files with 476 additions and 47 deletions

View File

@@ -346,6 +346,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page) - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.) - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol) - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app) - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings) - [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
@@ -356,6 +359,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama) - [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux) - [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support) - [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
### Cloud ### Cloud

View File

@@ -146,6 +146,7 @@ type ToolCall struct {
} }
type ToolCallFunction struct { type ToolCallFunction struct {
Index int `json:"index,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"` Arguments ToolCallFunctionArguments `json:"arguments"`
} }

View File

@@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
@@ -180,18 +179,14 @@ Weigh anchor!
t.Run("license", func(t *testing.T) { t.Run("license", func(t *testing.T) {
var b bytes.Buffer var b bytes.Buffer
license, err := os.ReadFile(filepath.Join("..", "LICENSE")) license := "MIT License\nCopyright (c) Ollama\n"
if err != nil {
t.Fatal(err)
}
if err := showInfo(&api.ShowResponse{ if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{ Details: api.ModelDetails{
Family: "test", Family: "test",
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
License: string(license), License: license,
}, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -49,10 +49,10 @@ Advanced parameters (optional):
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`) - `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API - `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
#### JSON mode #### JSON mode

View File

@@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant.
To use this: To use this:
1. Save it as a file (e.g. `Modelfile`) 1. Save it as a file (e.g. `Modelfile`)
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>'` 2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>`
3. `ollama run choose-a-model-name` 3. `ollama run choose-a-model-name`
4. Start using the model! 4. Start using the model!

View File

@@ -105,7 +105,7 @@ make apply-patches
**Pin to new base commit** **Pin to new base commit**
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env` To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`
#### Applying patches #### Applying patches

View File

@@ -833,10 +833,21 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
} }
} }
type multiLPath []string
func (m *multiLPath) Set(value string) error {
*m = append(*m, value)
return nil
}
func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}
func (s *Server) loadModel( func (s *Server) loadModel(
params llama.ModelParams, params llama.ModelParams,
mpath string, mpath string,
lpath string, lpath multiLPath,
ppath string, ppath string,
kvSize int, kvSize int,
flashAttention bool, flashAttention bool,
@@ -857,12 +868,14 @@ func (s *Server) loadModel(
panic(err) panic(err)
} }
if lpath != "" { if lpath.String() != "" {
err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads) for _, path := range lpath {
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
}
if ppath != "" { if ppath != "" {
var err error var err error
@@ -890,7 +903,6 @@ func main() {
mainGpu := flag.Int("main-gpu", 0, "Main GPU") mainGpu := flag.Int("main-gpu", 0, "Main GPU")
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention") flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size") kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
lpath := flag.String("lora", "", "Path to lora layer file")
port := flag.Int("port", 8080, "Port to expose the server on") port := flag.Int("port", 8080, "Port to expose the server on")
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)") verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -900,6 +912,9 @@ func main() {
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
requirements := flag.Bool("requirements", false, "print json requirement information") requirements := flag.Bool("requirements", false, "print json requirement information")
var lpaths multiLPath
flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
flag.Parse() flag.Parse()
if *requirements { if *requirements {
printRequirements(os.Stdout) printRequirements(os.Stdout)
@@ -946,7 +961,7 @@ func main() {
params := llama.ModelParams{ params := llama.ModelParams{
NumGpuLayers: *nGpuLayers, NumGpuLayers: *nGpuLayers,
MainGpu: *mainGpu, MainGpu: *mainGpu,
UseMmap: !*noMmap && *lpath == "", UseMmap: !*noMmap && lpaths.String() == "",
UseMlock: *mlock, UseMlock: *mlock,
TensorSplit: tensorSplitFloats, TensorSplit: tensorSplitFloats,
Progress: func(progress float32) { Progress: func(progress float32) {
@@ -955,7 +970,7 @@ func main() {
} }
server.ready.Add(1) server.ready.Add(1)
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache) go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
server.cond = sync.NewCond(&server.mu) server.cond = sync.NewCond(&server.mu)

View File

@@ -144,10 +144,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
// Loop through potential servers // Loop through potential servers
finalErr := errors.New("no suitable llama servers found") finalErr := errors.New("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
rDir, err := runners.Refresh(build.EmbedFS) rDir, err := runners.Refresh(build.EmbedFS)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -201,8 +197,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
} }
if len(adapters) > 0 { if len(adapters) > 0 {
// TODO: applying multiple adapters is not supported by the llama.cpp server yet for _, adapter := range adapters {
params = append(params, "--lora", adapters[0]) params = append(params, "--lora", adapter)
}
} }
if len(projectors) > 0 { if len(projectors) > 0 {

View File

@@ -140,6 +140,7 @@ type CompletionChunk struct {
type ToolCall struct { type ToolCall struct {
ID string `json:"id"` ID string `json:"id"`
Index int `json:"index"`
Type string `json:"type"` Type string `json:"type"`
Function struct { Function struct {
Name string `json:"name"` Name string `json:"name"`
@@ -200,12 +201,13 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { func toToolCalls(tc []api.ToolCall) []ToolCall {
toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) toolCalls := make([]ToolCall, len(tc))
for i, tc := range r.Message.ToolCalls { for i, tc := range tc {
toolCalls[i].ID = toolCallId() toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function" toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name toolCalls[i].Function.Name = tc.Function.Name
toolCalls[i].Index = tc.Function.Index
args, err := json.Marshal(tc.Function.Arguments) args, err := json.Marshal(tc.Function.Arguments)
if err != nil { if err != nil {
@@ -215,7 +217,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls[i].Function.Arguments = string(args) toolCalls[i].Function.Arguments = string(args)
} }
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{ return ChatCompletion{
Id: id, Id: id,
Object: "chat.completion", Object: "chat.completion",
@@ -244,6 +250,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
} }
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{ return ChatCompletionChunk{
Id: id, Id: id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
@@ -252,7 +259,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama", SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{ Choices: []ChunkChoice{{
Index: 0, Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content}, Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason

View File

@@ -195,7 +195,86 @@ func TestChatMiddleware(t *testing.T) {
Stream: &False, Stream: &False,
}, },
}, },
{
name: "chat handler with streaming tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris?"}
],
"stream": true,
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris?",
},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &True,
},
},
{ {
name: "chat handler error forwarding", name: "chat handler error forwarding",
body: `{ body: `{

View File

@@ -802,6 +802,12 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if mp.ProtocolScheme == "http" && !regOpts.Insecure { if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http") return errors.New("insecure protocol http")
} }
if mp.Namespace != strings.ToLower(mp.Namespace) {
return fmt.Errorf("namespace must be lowercase, but is %s", mp.Namespace)
}
if mp.Repository != strings.ToLower(mp.Repository) {
return fmt.Errorf("model name must be lowercase, but is %s", mp.Repository)
}
manifest, _, err := GetManifest(mp) manifest, _, err := GetManifest(mp)
if err != nil { if err != nil {

50
server/images_test.go Normal file
View File

@@ -0,0 +1,50 @@
package server
import (
"context"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestPushModel(t *testing.T) {
noOpProgress := func(resp api.ProgressResponse) {}
tests := []struct {
modelStr string
regOpts *registryOptions
wantErr string
}{
{
modelStr: "http://example.com/namespace/repo:tag",
regOpts: &registryOptions{Insecure: false},
wantErr: "insecure protocol http",
},
{
modelStr: "docker://Example/repo:tag",
regOpts: &registryOptions{},
wantErr: "namespace must be lowercase, but is Example",
},
{
modelStr: "docker://example/Repo:tag",
regOpts: &registryOptions{},
wantErr: "model name must be lowercase, but is Repo",
},
}
for _, tt := range tests {
t.Run(tt.modelStr, func(t *testing.T) {
err := PushModel(context.Background(), tt.modelStr, tt.regOpts, noOpProgress)
if tt.wantErr != "" {
if err == nil {
t.Errorf("PushModel() error = %v, wantErr %v", err, tt.wantErr)
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("PushModel() error = %v, wantErr %v", err, tt.wantErr)
}
return
}
})
}
}

View File

@@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},

View File

@@ -251,6 +251,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var b bytes.Buffer var b bytes.Buffer
if req.Context != nil { if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
s, err := r.Detokenize(c.Request.Context(), req.Context) s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1458,6 +1459,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil { if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@@ -1467,6 +1469,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
@@ -1492,7 +1496,37 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
// however this was a simple change for now without reworking streaming logic of this (and other)
// handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res ch <- res
return
}
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
res.Message.Content = ""
sb.Reset()
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
}
ch <- res
}
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }

View File

@@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -25,10 +26,14 @@ type mockRunner struct {
// CompletionRequest is only valid until the next call to Completion // CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest llm.CompletionRequest
llm.CompletionResponse llm.CompletionResponse
CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
} }
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r m.CompletionRequest = r
if m.CompletionFn != nil {
return m.CompletionFn(ctx, r, fn)
}
fn(m.CompletionResponse) fn(m.CompletionResponse)
return nil return nil
} }
@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
Model: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }} {{- if .Tools }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }} {{ .Tools }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" {{ end }}
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}"""
`, createBinFile(t, llm.KV{ `, createBinFile(t, llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"llama.block_count": uint32(1), "llama.block_count": uint32(1),
@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
}) })
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
if w.Code != http.StatusOK {
t.Fatalf("failed to create test-system model: %d", w.Code)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
}
mock.CompletionResponse = llm.CompletionResponse{
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "done",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
}
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &streamRequest,
})
if w.Code != http.StatusOK {
var errResp struct {
Error string `json:"error"`
}
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
t.Logf("Failed to decode error response: %v", err)
} else {
t.Logf("Error response: %s", errResp.Error)
}
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.ToolCalls == nil {
t.Error("expected tool calls, got nil")
}
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
},
}
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
}
})
t.Run("messages with tools (streaming)", func(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
}
// Simulate streaming response with multiple chunks
var wg sync.WaitGroup
wg.Add(1)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
// Send chunks with small delays to simulate streaming
responses := []llm.CompletionResponse{
{
Content: `{"name":"get_`,
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
},
{
Content: `weather","arguments":{"location":"Seattle`,
Done: false,
PromptEvalCount: 2,
PromptEvalDuration: 1,
},
{
Content: `, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "tool_call",
PromptEvalCount: 3,
PromptEvalDuration: 1,
},
}
for _, resp := range responses {
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
time.Sleep(10 * time.Millisecond) // Small delay between chunks
}
}
return nil
}
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &stream,
})
wg.Wait()
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Read and validate the streamed responses
decoder := json.NewDecoder(w.Body)
var finalToolCall api.ToolCall
for {
var resp api.ChatResponse
if err := decoder.Decode(&resp); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
if resp.Done {
if len(resp.Message.ToolCalls) != 1 {
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
}
finalToolCall = resp.Message.ToolCalls[0]
}
}
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
},
}
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
}
})
} }
func TestGenerate(t *testing.T) { func TestGenerate(t *testing.T) {