Compare commits
66 Commits
jyan/progr
...
v0.2.6
Author | SHA1 | Date | |
---|---|---|---|
![]() |
319fb1ce03 | ||
![]() |
b255445557 | ||
![]() |
b23424bb3c | ||
![]() |
5fd6988126 | ||
![]() |
5b82960df8 | ||
![]() |
cc9a252d8c | ||
![]() |
d281a6e603 | ||
![]() |
154f6f45d4 | ||
![]() |
0d41623b52 | ||
![]() |
c279f96371 | ||
![]() |
499e87c9ba | ||
![]() |
cd0853f2d5 | ||
![]() |
d290e87513 | ||
![]() |
97c20ede33 | ||
![]() |
5a83f79afd | ||
![]() |
987dbab0b0 | ||
![]() |
a8388beb94 | ||
![]() |
5afbb60fc4 | ||
![]() |
4cb5d7decc | ||
![]() |
8eac50dd4f | ||
![]() |
4a565cbf94 | ||
![]() |
64039df6d7 | ||
![]() |
7ac6d462ec | ||
![]() |
ef5136a745 | ||
![]() |
8288ec8824 | ||
![]() |
d02bbebb11 | ||
![]() |
224337b32f | ||
![]() |
9e35d9bbee | ||
![]() |
b9f5e16c80 | ||
![]() |
e9f7f36029 | ||
![]() |
057d31861e | ||
![]() |
f7ee012300 | ||
![]() |
1ed0aa8fea | ||
![]() |
ef98803d63 | ||
![]() |
02fea420e5 | ||
![]() |
22c5451fc2 | ||
![]() |
23ebbaa46e | ||
![]() |
9ac0a7a50b | ||
![]() |
e5c65a85df | ||
![]() |
33627331a3 | ||
![]() |
36c87c433b | ||
![]() |
179737feb7 | ||
![]() |
47353f5ee4 | ||
![]() |
10e768826c | ||
![]() |
5056bb9c01 | ||
![]() |
c4cf8ad559 | ||
![]() |
57ec6901eb | ||
![]() |
e64f9ebb44 | ||
![]() |
791650ddef | ||
![]() |
efbf41ed81 | ||
![]() |
cf15589851 | ||
![]() |
19753c18c0 | ||
![]() |
41be28096a | ||
![]() |
37a570f962 | ||
![]() |
5a739ff4cb | ||
![]() |
4e262eb2a8 | ||
![]() |
4cfcbc328f | ||
![]() |
79292ff3e0 | ||
![]() |
8ea500441d | ||
![]() |
b50c818623 | ||
![]() |
b99e750b62 | ||
![]() |
1f50356e8e | ||
![]() |
22c81f62ec | ||
![]() |
f6f759fc5f | ||
![]() |
b44320db13 | ||
![]() |
784bf88b0d |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -147,7 +147,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
write-host "downloading AMD HIP Installer"
|
write-host "downloading AMD HIP Installer"
|
||||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||||
write-host "Installing AMD HIP"
|
write-host "Installing AMD HIP"
|
||||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||||
write-host "Completed AMD HIP"
|
write-host "Completed AMD HIP"
|
||||||
|
4
.github/workflows/test.yaml
vendored
4
.github/workflows/test.yaml
vendored
@@ -126,7 +126,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
rocm-version:
|
rocm-version:
|
||||||
- '6.1.1'
|
- '6.1.2'
|
||||||
runs-on: linux
|
runs-on: linux
|
||||||
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
|
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
|
||||||
steps:
|
steps:
|
||||||
@@ -169,7 +169,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
write-host "downloading AMD HIP Installer"
|
write-host "downloading AMD HIP Installer"
|
||||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||||
write-host "Installing AMD HIP"
|
write-host "Installing AMD HIP"
|
||||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||||
write-host "Completed AMD HIP"
|
write-host "Completed AMD HIP"
|
||||||
|
@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
|
|||||||
ARG CMAKE_VERSION=3.22.1
|
ARG CMAKE_VERSION=3.22.1
|
||||||
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
|
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
|
||||||
ARG CUDA_VERSION=11.3.1
|
ARG CUDA_VERSION=11.3.1
|
||||||
ARG ROCM_VERSION=6.1.1
|
ARG ROCM_VERSION=6.1.2
|
||||||
|
|
||||||
# Copy the minimal context we need to run the generate scripts
|
# Copy the minimal context we need to run the generate scripts
|
||||||
FROM scratch AS llm-code
|
FROM scratch AS llm-code
|
||||||
|
@@ -293,6 +293,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
||||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||||
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
|
||||||
|
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
|
||||||
|
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
|
||||||
|
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
|
||||||
|
|
||||||
### Terminal
|
### Terminal
|
||||||
|
|
||||||
|
@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Embeddings generates embeddings from a model.
|
// Embed generates embeddings from a model.
|
||||||
|
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
|
||||||
|
var resp EmbedResponse
|
||||||
|
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embeddings generates an embedding from a model.
|
||||||
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||||
var resp EmbeddingResponse
|
var resp EmbeddingResponse
|
||||||
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
|
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
|
||||||
|
102
api/types.go
102
api/types.go
@@ -47,6 +47,9 @@ type GenerateRequest struct {
|
|||||||
// Prompt is the textual prompt to send to the model.
|
// Prompt is the textual prompt to send to the model.
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
||||||
|
// Suffix is the text that comes after the inserted text.
|
||||||
|
Suffix string `json:"suffix"`
|
||||||
|
|
||||||
// System overrides the model's default system message/prompt.
|
// System overrides the model's default system message/prompt.
|
||||||
System string `json:"system"`
|
System string `json:"system"`
|
||||||
|
|
||||||
@@ -97,17 +100,80 @@ type ChatRequest struct {
|
|||||||
// followin the request.
|
// followin the request.
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
|
// Tools is an optional list of tools the model has access to.
|
||||||
|
Tools `json:"tools,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Tools []Tool
|
||||||
|
|
||||||
|
func (t Tools) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
// Message is a single message in a chat sequence. The message contains the
|
// Message is a single message in a chat sequence. The message contains the
|
||||||
// role ("system", "user", or "assistant"), the content and an optional list
|
// role ("system", "user", or "assistant"), the content and an optional list
|
||||||
// of images.
|
// of images.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content,omitempty"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
|
type Alias Message
|
||||||
|
var a Alias
|
||||||
|
if err := json.Unmarshal(b, &a); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*m = Message(a)
|
||||||
|
m.Role = strings.ToLower(m.Role)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Function ToolCallFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCallFunctionArguments map[string]any
|
||||||
|
|
||||||
|
func (t *ToolCallFunctionArguments) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function ToolFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
} `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ToolFunction) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||||
@@ -173,6 +239,30 @@ type Runner struct {
|
|||||||
NumThread int `json:"num_thread,omitempty"`
|
NumThread int `json:"num_thread,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbedRequest is the request passed to [Client.Embed].
|
||||||
|
type EmbedRequest struct {
|
||||||
|
// Model is the model name.
|
||||||
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// Input is the input to embed.
|
||||||
|
Input any `json:"input"`
|
||||||
|
|
||||||
|
// KeepAlive controls how long the model will stay loaded in memory following
|
||||||
|
// this request.
|
||||||
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
|
Truncate *bool `json:"truncate,omitempty"`
|
||||||
|
|
||||||
|
// Options lists model-specific options.
|
||||||
|
Options map[string]interface{} `json:"options"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedResponse is the response from [Client.Embed].
|
||||||
|
type EmbedResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Embeddings [][]float32 `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
// Model is the model name.
|
// Model is the model name.
|
||||||
@@ -219,8 +309,10 @@ type DeleteRequest struct {
|
|||||||
|
|
||||||
// ShowRequest is the request passed to [Client.Show].
|
// ShowRequest is the request passed to [Client.Show].
|
||||||
type ShowRequest struct {
|
type ShowRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
System string `json:"system"`
|
System string `json:"system"`
|
||||||
|
|
||||||
|
// Template is deprecated
|
||||||
Template string `json:"template"`
|
Template string `json:"template"`
|
||||||
Verbose bool `json:"verbose"`
|
Verbose bool `json:"verbose"`
|
||||||
|
|
||||||
|
@@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMessage_UnmarshalJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{`{"role": "USER", "content": "Hello!"}`, "user"},
|
||||||
|
{`{"role": "System", "content": "Initialization complete."}`, "system"},
|
||||||
|
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
|
||||||
|
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
var msg Message
|
||||||
|
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.Role != test.expected {
|
||||||
|
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
|
|||||||
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
|
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
|
||||||
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
|
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
|
||||||
|
|
||||||
|
[InstallDelete]
|
||||||
|
Type: filesandordirs; Name: "{%TEMP}\ollama*"
|
||||||
|
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
|
||||||
|
|
||||||
[Messages]
|
[Messages]
|
||||||
WizardReady=Ollama Windows Preview
|
WizardReady=Ollama Windows Preview
|
||||||
ReadyLabel1=%nLet's get you up and running with your own large language models.
|
ReadyLabel1=%nLet's get you up and running with your own large language models.
|
||||||
|
@@ -843,7 +843,6 @@ type runOptions struct {
|
|||||||
WordWrap bool
|
WordWrap bool
|
||||||
Format string
|
Format string
|
||||||
System string
|
System string
|
||||||
Template string
|
|
||||||
Images []api.ImageData
|
Images []api.ImageData
|
||||||
Options map[string]interface{}
|
Options map[string]interface{}
|
||||||
MultiModal bool
|
MultiModal bool
|
||||||
@@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
Images: opts.Images,
|
Images: opts.Images,
|
||||||
Format: opts.Format,
|
Format: opts.Format,
|
||||||
System: opts.System,
|
System: opts.System,
|
||||||
Template: opts.Template,
|
|
||||||
Options: opts.Options,
|
Options: opts.Options,
|
||||||
KeepAlive: opts.KeepAlive,
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
|
@@ -27,7 +27,6 @@ const (
|
|||||||
MultilineNone MultilineState = iota
|
MultilineNone MultilineState = iota
|
||||||
MultilinePrompt
|
MultilinePrompt
|
||||||
MultilineSystem
|
MultilineSystem
|
||||||
MultilineTemplate
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||||
@@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||||
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
|
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
|
||||||
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
|
fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
|
||||||
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
|
|
||||||
fmt.Fprintln(os.Stderr, " /set history Enable history")
|
fmt.Fprintln(os.Stderr, " /set history Enable history")
|
||||||
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
|
||||||
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
|
||||||
@@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||||
fmt.Println("Set system message.")
|
fmt.Println("Set system message.")
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
case MultilineTemplate:
|
|
||||||
opts.Template = sb.String()
|
|
||||||
fmt.Println("Set prompt template.")
|
|
||||||
sb.Reset()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
multiline = MultilineNone
|
multiline = MultilineNone
|
||||||
@@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
|
||||||
opts.Options[args[2]] = fp[args[2]]
|
opts.Options[args[2]] = fp[args[2]]
|
||||||
case "system", "template":
|
case "system":
|
||||||
if len(args) < 3 {
|
if len(args) < 3 {
|
||||||
usageSet()
|
usageSet()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if args[1] == "system" {
|
multiline = MultilineSystem
|
||||||
multiline = MultilineSystem
|
|
||||||
} else if args[1] == "template" {
|
|
||||||
multiline = MultilineTemplate
|
|
||||||
}
|
|
||||||
|
|
||||||
line := strings.Join(args[2:], " ")
|
line := strings.Join(args[2:], " ")
|
||||||
line, ok := strings.CutPrefix(line, `"""`)
|
line, ok := strings.CutPrefix(line, `"""`)
|
||||||
@@ -356,23 +346,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if args[1] == "system" {
|
opts.System = sb.String() // for display in modelfile
|
||||||
opts.System = sb.String() // for display in modelfile
|
newMessage := api.Message{Role: "system", Content: sb.String()}
|
||||||
newMessage := api.Message{Role: "system", Content: sb.String()}
|
// Check if the slice is not empty and the last message is from 'system'
|
||||||
// Check if the slice is not empty and the last message is from 'system'
|
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
|
||||||
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
|
// Replace the last message
|
||||||
// Replace the last message
|
opts.Messages[len(opts.Messages)-1] = newMessage
|
||||||
opts.Messages[len(opts.Messages)-1] = newMessage
|
} else {
|
||||||
} else {
|
opts.Messages = append(opts.Messages, newMessage)
|
||||||
opts.Messages = append(opts.Messages, newMessage)
|
|
||||||
}
|
|
||||||
fmt.Println("Set system message.")
|
|
||||||
sb.Reset()
|
|
||||||
} else if args[1] == "template" {
|
|
||||||
opts.Template = sb.String()
|
|
||||||
fmt.Println("Set prompt template.")
|
|
||||||
sb.Reset()
|
|
||||||
}
|
}
|
||||||
|
fmt.Println("Set system message.")
|
||||||
|
sb.Reset()
|
||||||
|
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
continue
|
continue
|
||||||
@@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
req := &api.ShowRequest{
|
req := &api.ShowRequest{
|
||||||
Name: opts.Model,
|
Name: opts.Model,
|
||||||
System: opts.System,
|
System: opts.System,
|
||||||
Template: opts.Template,
|
|
||||||
Options: opts.Options,
|
Options: opts.Options,
|
||||||
}
|
}
|
||||||
resp, err := client.Show(cmd.Context(), req)
|
resp, err := client.Show(cmd.Context(), req)
|
||||||
@@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
fmt.Println("No system message was specified for this model.")
|
fmt.Println("No system message was specified for this model.")
|
||||||
}
|
}
|
||||||
case "template":
|
case "template":
|
||||||
switch {
|
if resp.Template != "" {
|
||||||
case opts.Template != "":
|
|
||||||
fmt.Println(opts.Template + "\n")
|
|
||||||
case resp.Template != "":
|
|
||||||
fmt.Println(resp.Template)
|
fmt.Println(resp.Template)
|
||||||
default:
|
} else {
|
||||||
fmt.Println("No prompt template was specified for this model.")
|
fmt.Println("No prompt template was specified for this model.")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string {
|
|||||||
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
|
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Template != "" {
|
|
||||||
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
|
|
||||||
}
|
|
||||||
|
|
||||||
keys := make([]string, 0)
|
keys := make([]string, 0)
|
||||||
for k := range opts.Options {
|
for k := range opts.Options {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
|
@@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) {
|
|||||||
opts := runOptions{
|
opts := runOptions{
|
||||||
Model: "hork",
|
Model: "hork",
|
||||||
System: "You are part horse and part shark, but all hork. Do horklike things",
|
System: "You are part horse and part shark, but all hork. Do horklike things",
|
||||||
Template: "This is a template.",
|
|
||||||
Messages: []api.Message{
|
Messages: []api.Message{
|
||||||
{Role: "user", Content: "Hey there hork!"},
|
{Role: "user", Content: "Hey there hork!"},
|
||||||
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
|
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
|
||||||
@@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) {
|
|||||||
mf := buildModelfile(opts)
|
mf := buildModelfile(opts)
|
||||||
expectedModelfile := `FROM {{.Model}}
|
expectedModelfile := `FROM {{.Model}}
|
||||||
SYSTEM """{{.System}}"""
|
SYSTEM """{{.System}}"""
|
||||||
TEMPLATE """{{.Template}}"""
|
|
||||||
PARAMETER penalize_newline false
|
PARAMETER penalize_newline false
|
||||||
PARAMETER seed 42
|
PARAMETER seed 42
|
||||||
PARAMETER stop [hi there]
|
PARAMETER stop [hi there]
|
||||||
@@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
|||||||
mf = buildModelfile(opts)
|
mf = buildModelfile(opts)
|
||||||
expectedModelfile = `FROM {{.ParentModel}}
|
expectedModelfile = `FROM {{.ParentModel}}
|
||||||
SYSTEM """{{.System}}"""
|
SYSTEM """{{.System}}"""
|
||||||
TEMPLATE """{{.Template}}"""
|
|
||||||
PARAMETER penalize_newline false
|
PARAMETER penalize_newline false
|
||||||
PARAMETER seed 42
|
PARAMETER seed 42
|
||||||
PARAMETER stop [hi there]
|
PARAMETER stop [hi there]
|
||||||
|
@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
|
|||||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||||
|
|
||||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
@@ -103,10 +103,6 @@ curl http://localhost:11434/v1/chat/completions \
|
|||||||
- [ ] `user`
|
- [ ] `user`
|
||||||
- [ ] `n`
|
- [ ] `n`
|
||||||
|
|
||||||
#### Notes
|
|
||||||
|
|
||||||
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
|
||||||
|
|
||||||
## Models
|
## Models
|
||||||
|
|
||||||
Before using a model, pull it locally `ollama pull`:
|
Before using a model, pull it locally `ollama pull`:
|
||||||
|
@@ -49,9 +49,17 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func commonAMDValidateLibDir() (string, error) {
|
func commonAMDValidateLibDir() (string, error) {
|
||||||
// We try to favor system paths first, so that we can wire up the subprocess to use
|
// Favor our bundled version
|
||||||
// the system version. Only use our bundled version if the system version doesn't work
|
|
||||||
// This gives users a more recovery options if versions have subtle problems at runtime
|
// Installer payload location if we're running the installed binary
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err == nil {
|
||||||
|
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||||
|
if rocmLibUsable(rocmTargetDir) {
|
||||||
|
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||||
|
return rocmTargetDir, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Prefer explicit HIP env var
|
// Prefer explicit HIP env var
|
||||||
hipPath := os.Getenv("HIP_PATH")
|
hipPath := os.Getenv("HIP_PATH")
|
||||||
@@ -87,14 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Installer payload location if we're running the installed binary
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err == nil {
|
|
||||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
|
||||||
return rocmTargetDir, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||||
}
|
}
|
||||||
|
@@ -84,9 +84,8 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("hipDriverGetVersion", "version", version)
|
slog.Debug("hipDriverGetVersion", "version", version)
|
||||||
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
|
driverMajor = version / 10000000
|
||||||
driverMajor = version / 1000
|
driverMinor = (version - (driverMajor * 10000000)) / 100000
|
||||||
driverMinor = (version - (driverMajor * 1000)) / 10
|
|
||||||
|
|
||||||
return driverMajor, driverMinor, nil
|
return driverMajor, driverMinor, nil
|
||||||
}
|
}
|
||||||
|
@@ -22,8 +22,8 @@ const (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
// Used to validate if the given ROCm lib is usable
|
// Used to validate if the given ROCm lib is usable
|
||||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
|
||||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
|
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
|
||||||
)
|
)
|
||||||
|
|
||||||
func AMDGetGPUInfo() []RocmGPUInfo {
|
func AMDGetGPUInfo() []RocmGPUInfo {
|
||||||
@@ -35,12 +35,11 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
|||||||
}
|
}
|
||||||
defer hl.Release()
|
defer hl.Release()
|
||||||
|
|
||||||
// TODO - this reports incorrect version information, so omitting for now
|
driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
||||||
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
|
if err != nil {
|
||||||
// if err != nil {
|
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||||
// // For now this is benign, but we may eventually need to fail compatibility checks
|
slog.Debug("error looking up amd driver version", "error", err)
|
||||||
// slog.Debug("error looking up amd driver version", "error", err)
|
}
|
||||||
// }
|
|
||||||
|
|
||||||
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
||||||
count := hl.HipGetDeviceCount()
|
count := hl.HipGetDeviceCount()
|
||||||
@@ -132,10 +131,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
|||||||
MinimumMemory: rocmMinimumMemory,
|
MinimumMemory: rocmMinimumMemory,
|
||||||
Name: name,
|
Name: name,
|
||||||
Compute: gfx,
|
Compute: gfx,
|
||||||
|
DriverMajor: driverMajor,
|
||||||
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
|
DriverMinor: driverMinor,
|
||||||
// DriverMajor: driverMajor,
|
|
||||||
// DriverMinor: driverMinor,
|
|
||||||
},
|
},
|
||||||
index: i,
|
index: i,
|
||||||
}
|
}
|
||||||
|
30
gpu/gpu.go
30
gpu/gpu.go
@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
gpuInfo.DriverMajor = driverMajor
|
gpuInfo.DriverMajor = driverMajor
|
||||||
gpuInfo.DriverMinor = driverMinor
|
gpuInfo.DriverMinor = driverMinor
|
||||||
|
|
||||||
|
// query the management library as well so we can record any skew between the two
|
||||||
|
// which represents overhead on the GPU we must set aside on subsequent updates
|
||||||
|
if cHandles.nvml != nil {
|
||||||
|
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
|
||||||
|
if memInfo.err != nil {
|
||||||
|
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||||
|
C.free(unsafe.Pointer(memInfo.err))
|
||||||
|
} else {
|
||||||
|
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
|
||||||
|
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
|
||||||
|
slog.Info("detected OS VRAM overhead",
|
||||||
|
"id", gpuInfo.ID,
|
||||||
|
"library", gpuInfo.Library,
|
||||||
|
"compute", gpuInfo.Compute,
|
||||||
|
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
|
||||||
|
"name", gpuInfo.Name,
|
||||||
|
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||||
}
|
}
|
||||||
@@ -338,14 +360,17 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
"before",
|
"before",
|
||||||
"total", format.HumanBytes2(cpus[0].TotalMemory),
|
"total", format.HumanBytes2(cpus[0].TotalMemory),
|
||||||
"free", format.HumanBytes2(cpus[0].FreeMemory),
|
"free", format.HumanBytes2(cpus[0].FreeMemory),
|
||||||
|
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
|
||||||
),
|
),
|
||||||
slog.Group(
|
slog.Group(
|
||||||
"now",
|
"now",
|
||||||
"total", format.HumanBytes2(mem.TotalMemory),
|
"total", format.HumanBytes2(mem.TotalMemory),
|
||||||
"free", format.HumanBytes2(mem.FreeMemory),
|
"free", format.HumanBytes2(mem.FreeMemory),
|
||||||
|
"free_swap", format.HumanBytes2(mem.FreeSwap),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
cpus[0].FreeMemory = mem.FreeMemory
|
cpus[0].FreeMemory = mem.FreeMemory
|
||||||
|
cpus[0].FreeSwap = mem.FreeSwap
|
||||||
}
|
}
|
||||||
|
|
||||||
var memInfo C.mem_info_t
|
var memInfo C.mem_info_t
|
||||||
@@ -374,9 +399,14 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
slog.Warn("error looking up nvidia GPU memory")
|
slog.Warn("error looking up nvidia GPU memory")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
|
||||||
|
// When using the management library update based on recorded overhead
|
||||||
|
memInfo.free -= C.uint64_t(gpu.OSOverhead)
|
||||||
|
}
|
||||||
slog.Debug("updating cuda memory data",
|
slog.Debug("updating cuda memory data",
|
||||||
"gpu", gpu.ID,
|
"gpu", gpu.ID,
|
||||||
"name", gpu.Name,
|
"name", gpu.Name,
|
||||||
|
"overhead", format.HumanBytes2(gpu.OSOverhead),
|
||||||
slog.Group(
|
slog.Group(
|
||||||
"before",
|
"before",
|
||||||
"total", format.HumanBytes2(gpu.TotalMemory),
|
"total", format.HumanBytes2(gpu.TotalMemory),
|
||||||
|
@@ -57,6 +57,7 @@ func GetCPUMem() (memInfo, error) {
|
|||||||
return memInfo{
|
return memInfo{
|
||||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||||
FreeMemory: uint64(C.getFreeMemory()),
|
FreeMemory: uint64(C.getFreeMemory()),
|
||||||
|
// FreeSwap omitted as Darwin uses dynamic paging
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
|
|||||||
|
|
||||||
func GetCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
var mem memInfo
|
var mem memInfo
|
||||||
var total, available, free, buffers, cached uint64
|
var total, available, free, buffers, cached, freeSwap uint64
|
||||||
f, err := os.Open("/proc/meminfo")
|
f, err := os.Open("/proc/meminfo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mem, err
|
return mem, err
|
||||||
@@ -70,20 +70,21 @@ func GetCPUMem() (memInfo, error) {
|
|||||||
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
|
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
|
||||||
case strings.HasPrefix(line, "Cached:"):
|
case strings.HasPrefix(line, "Cached:"):
|
||||||
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
|
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
|
||||||
|
case strings.HasPrefix(line, "SwapFree:"):
|
||||||
|
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mem, err
|
return mem, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if total > 0 && available > 0 {
|
|
||||||
mem.TotalMemory = total * format.KibiByte
|
|
||||||
mem.FreeMemory = available * format.KibiByte
|
|
||||||
return mem, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
mem.TotalMemory = total * format.KibiByte
|
mem.TotalMemory = total * format.KibiByte
|
||||||
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
|
mem.FreeSwap = freeSwap * format.KibiByte
|
||||||
|
if available > 0 {
|
||||||
|
mem.FreeMemory = available * format.KibiByte
|
||||||
|
} else {
|
||||||
|
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
|
||||||
|
}
|
||||||
return mem, nil
|
return mem, nil
|
||||||
}
|
}
|
||||||
|
@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
|
|||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
|
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
|
||||||
}
|
}
|
||||||
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
|
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
|
||||||
}
|
}
|
||||||
|
@@ -10,6 +10,7 @@ import (
|
|||||||
type memInfo struct {
|
type memInfo struct {
|
||||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||||
|
FreeSwap uint64 `json:"free_swap,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Beginning of an `ollama info` command
|
// Beginning of an `ollama info` command
|
||||||
@@ -52,7 +53,8 @@ type CPUInfo struct {
|
|||||||
|
|
||||||
type CudaGPUInfo struct {
|
type CudaGPUInfo struct {
|
||||||
GpuInfo
|
GpuInfo
|
||||||
index int //nolint:unused,nolintlint
|
OSOverhead uint64 // Memory overhead between the driver library and management library
|
||||||
|
index int //nolint:unused,nolintlint
|
||||||
}
|
}
|
||||||
type CudaGPUInfoList []CudaGPUInfo
|
type CudaGPUInfoList []CudaGPUInfo
|
||||||
|
|
||||||
|
152
integration/embed_test.go
Normal file
152
integration/embed_test.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAllMiniLMEmbed(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := embedTestHelper(ctx, t, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings) != 1 {
|
||||||
|
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings[0]) != 384 {
|
||||||
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Embeddings[0][0] != 0.010071031 {
|
||||||
|
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: []string{"why is the sky blue?", "why is the grass green?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := embedTestHelper(ctx, t, req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings) != 2 {
|
||||||
|
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings[0]) != 384 {
|
||||||
|
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
|
||||||
|
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
truncTrue, truncFalse := true, false
|
||||||
|
|
||||||
|
type testReq struct {
|
||||||
|
Name string
|
||||||
|
Request api.EmbedRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
reqs := []testReq{
|
||||||
|
{
|
||||||
|
Name: "Target Truncation",
|
||||||
|
Request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Default Truncate",
|
||||||
|
Request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Options: map[string]any{"num_ctx": 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Explicit Truncate",
|
||||||
|
Request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res := make(map[string]*api.EmbedResponse)
|
||||||
|
|
||||||
|
for _, req := range reqs {
|
||||||
|
response, err := embedTestHelper(ctx, t, req.Request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
res[req.Name] = response
|
||||||
|
}
|
||||||
|
|
||||||
|
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||||
|
t.Fatal("expected default request to truncate correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||||
|
t.Fatal("expected default request and truncate true request to be the same")
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that truncate set to false returns an error if context length is exceeded
|
||||||
|
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncFalse,
|
||||||
|
Options: map[string]any{"num_ctx": 1},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := client.Embed(ctx, &req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
37
llm/ext_server/server.cpp
vendored
37
llm/ext_server/server.cpp
vendored
@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
|
|||||||
prompt = "";
|
prompt = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
json image_data;
|
if (prompt.size() == 1) {
|
||||||
if (body.count("image_data") != 0) {
|
prompt = prompt[0];
|
||||||
image_data = body["image_data"];
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
image_data = "";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// create and queue the task
|
// create and queue the task
|
||||||
const int task_id = llama.queue_tasks.get_new_id();
|
json responses;
|
||||||
llama.queue_results.add_waiting_task_id(task_id);
|
{
|
||||||
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
|
const int id_task = llama.queue_tasks.get_new_id();
|
||||||
|
llama.queue_results.add_waiting_task_id(id_task);
|
||||||
|
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
task_result result = llama.queue_results.recv(task_id);
|
task_result result = llama.queue_results.recv(id_task);
|
||||||
llama.queue_results.remove_waiting_task_id(task_id);
|
llama.queue_results.remove_waiting_task_id(id_task);
|
||||||
|
if (result.error) {
|
||||||
|
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||||
|
}
|
||||||
|
|
||||||
// send the result
|
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
json embeddings = json::array();
|
||||||
|
for (auto & elem : responses) {
|
||||||
|
embeddings.push_back(elem.at("embedding"));
|
||||||
|
}
|
||||||
|
// send the result
|
||||||
|
json embedding_res = json{{"embedding", embeddings}};
|
||||||
|
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
||||||
|
@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
|
|||||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
|
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
|
||||||
echo "Building custom CUDA GPU"
|
echo "Building custom CUDA GPU"
|
||||||
else
|
else
|
||||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
|
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
|
||||||
fi
|
fi
|
||||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
|
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
|
||||||
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||||
|
@@ -6,18 +6,9 @@ function amdGPUs {
|
|||||||
if ($env:AMDGPU_TARGETS) {
|
if ($env:AMDGPU_TARGETS) {
|
||||||
return $env:AMDGPU_TARGETS
|
return $env:AMDGPU_TARGETS
|
||||||
}
|
}
|
||||||
# TODO - load from some common data file for linux + windows build consistency
|
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||||
$GPU_LIST = @(
|
$GPU_LIST = @(
|
||||||
"gfx900"
|
|
||||||
"gfx906:xnack-"
|
"gfx906:xnack-"
|
||||||
"gfx908:xnack-"
|
|
||||||
"gfx90a:xnack+"
|
|
||||||
"gfx90a:xnack-"
|
|
||||||
"gfx940"
|
|
||||||
"gfx941"
|
|
||||||
"gfx942"
|
|
||||||
"gfx1010"
|
|
||||||
"gfx1012"
|
|
||||||
"gfx1030"
|
"gfx1030"
|
||||||
"gfx1100"
|
"gfx1100"
|
||||||
"gfx1101"
|
"gfx1101"
|
||||||
@@ -395,7 +386,6 @@ function build_rocm() {
|
|||||||
sign
|
sign
|
||||||
install
|
install
|
||||||
|
|
||||||
# Assumes v5.7, may need adjustments for v6
|
|
||||||
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
||||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
|
26
llm/ggml.go
26
llm/ggml.go
@@ -424,6 +424,32 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
|||||||
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||||
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
||||||
)
|
)
|
||||||
|
case "chatglm":
|
||||||
|
fullOffload = 4 * batch * (embedding + vocab)
|
||||||
|
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
|
||||||
|
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
|
||||||
|
fullOffload = max(
|
||||||
|
fullOffload,
|
||||||
|
4*batch*(2+
|
||||||
|
2*embedding+
|
||||||
|
context+
|
||||||
|
context*heads+
|
||||||
|
embeddingHeadsK*heads+
|
||||||
|
qkvBias.Shape[0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
partialOffload = max(
|
||||||
|
partialOffload,
|
||||||
|
4*batch*(1+
|
||||||
|
2*embedding+
|
||||||
|
embeddingHeadsK*heads+
|
||||||
|
context+
|
||||||
|
context*heads)+
|
||||||
|
4*embeddingHeadsK*context+
|
||||||
|
4*context*embeddingHeadsK+
|
||||||
|
4*qkvBias.Shape[0],
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{
|
|||||||
"tokenizer.ggml.add_bos_token",
|
"tokenizer.ggml.add_bos_token",
|
||||||
"tokenizer.ggml.add_eos_token",
|
"tokenizer.ggml.add_eos_token",
|
||||||
"tokenizer.chat_template",
|
"tokenizer.chat_template",
|
||||||
|
"bert.pooling_type",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -33,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
|||||||
params.ftype = ftype.Value()
|
params.ftype = ftype.Value()
|
||||||
|
|
||||||
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@@ -33,7 +33,7 @@ type LlamaServer interface {
|
|||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
Embed(ctx context.Context, input []string) ([][]float32, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -88,6 +88,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
var estimate MemoryEstimate
|
var estimate MemoryEstimate
|
||||||
var systemTotalMemory uint64
|
var systemTotalMemory uint64
|
||||||
var systemFreeMemory uint64
|
var systemFreeMemory uint64
|
||||||
|
var systemSwapFreeMemory uint64
|
||||||
|
|
||||||
systemMemInfo, err := gpu.GetCPUMem()
|
systemMemInfo, err := gpu.GetCPUMem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,7 +96,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
} else {
|
} else {
|
||||||
systemTotalMemory = systemMemInfo.TotalMemory
|
systemTotalMemory = systemMemInfo.TotalMemory
|
||||||
systemFreeMemory = systemMemInfo.FreeMemory
|
systemFreeMemory = systemMemInfo.FreeMemory
|
||||||
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
|
systemSwapFreeMemory = systemMemInfo.FreeSwap
|
||||||
|
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
||||||
@@ -122,6 +124,16 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// On linux, over-allocating CPU memory will almost always result in an error
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
|
||||||
|
available := systemFreeMemory + systemSwapFreeMemory
|
||||||
|
if systemMemoryRequired > available {
|
||||||
|
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||||
|
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
estimate.log()
|
estimate.log()
|
||||||
|
|
||||||
// Loop through potential servers
|
// Loop through potential servers
|
||||||
@@ -254,10 +266,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||||
}
|
}
|
||||||
|
|
||||||
if estimate.TensorSplit != "" {
|
|
||||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range len(servers) {
|
for i := range len(servers) {
|
||||||
dir := availableServers[servers[i]]
|
dir := availableServers[servers[i]]
|
||||||
if dir == "" {
|
if dir == "" {
|
||||||
@@ -859,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbedRequest struct {
|
||||||
Content string `json:"content"`
|
Content []string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbedResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding [][]float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -882,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
data, err := json.Marshal(EmbedRequest{Content: input})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -909,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||||||
return nil, fmt.Errorf("%s", body)
|
return nil, fmt.Errorf("%s", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var embedding EmbeddingResponse
|
var embedding EmbedResponse
|
||||||
if err := json.Unmarshal(body, &embedding); err != nil {
|
if err := json.Unmarshal(body, &embedding); err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||||
}
|
}
|
||||||
|
244
openai/openai.go
244
openai/openai.go
@@ -3,11 +3,14 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -27,8 +30,9 @@ type ErrorResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content any `json:"content"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
@@ -59,6 +63,11 @@ type ResponseFormat struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbedRequest struct {
|
||||||
|
Input any `json:"input"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
type ChatCompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
@@ -71,6 +80,7 @@ type ChatCompletionRequest struct {
|
|||||||
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
||||||
TopP *float64 `json:"top_p"`
|
TopP *float64 `json:"top_p"`
|
||||||
ResponseFormat *ResponseFormat `json:"response_format"`
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||||
|
Tools []api.Tool `json:"tools"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletion struct {
|
type ChatCompletion struct {
|
||||||
@@ -104,6 +114,7 @@ type CompletionRequest struct {
|
|||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
Temperature *float32 `json:"temperature"`
|
Temperature *float32 `json:"temperature"`
|
||||||
TopP float32 `json:"top_p"`
|
TopP float32 `json:"top_p"`
|
||||||
|
Suffix string `json:"suffix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Completion struct {
|
type Completion struct {
|
||||||
@@ -125,6 +136,15 @@ type CompletionChunk struct {
|
|||||||
SystemFingerprint string `json:"system_fingerprint"`
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
} `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
@@ -132,11 +152,23 @@ type Model struct {
|
|||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Embedding struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
type ListCompletion struct {
|
type ListCompletion struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Data []Model `json:"data"`
|
Data []Model `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbeddingList struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Embedding `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
func NewError(code int, message string) ErrorResponse {
|
func NewError(code int, message string) ErrorResponse {
|
||||||
var etype string
|
var etype string
|
||||||
switch code {
|
switch code {
|
||||||
@@ -151,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
|
|||||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toolCallId() string {
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b := make([]byte, 8)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||||
|
}
|
||||||
|
return "call_" + strings.ToLower(string(b))
|
||||||
|
}
|
||||||
|
|
||||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
|
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
||||||
|
for i, tc := range r.Message.ToolCalls {
|
||||||
|
toolCalls[i].ID = toolCallId()
|
||||||
|
toolCalls[i].Type = "function"
|
||||||
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
|
||||||
|
args, err := json.Marshal(tc.Function.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("could not marshall function arguments to json", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls[i].Function.Arguments = string(args)
|
||||||
|
}
|
||||||
|
|
||||||
return ChatCompletion{
|
return ChatCompletion{
|
||||||
Id: id,
|
Id: id,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@@ -160,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
SystemFingerprint: "fp_ollama",
|
SystemFingerprint: "fp_ollama",
|
||||||
Choices: []Choice{{
|
Choices: []Choice{{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
|
||||||
FinishReason: func(reason string) *string {
|
FinishReason: func(reason string) *string {
|
||||||
if len(reason) > 0 {
|
if len(reason) > 0 {
|
||||||
return &reason
|
return &reason
|
||||||
@@ -169,7 +225,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
@@ -215,7 +270,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
|||||||
}(r.DoneReason),
|
}(r.DoneReason),
|
||||||
}},
|
}},
|
||||||
Usage: Usage{
|
Usage: Usage{
|
||||||
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
|
||||||
PromptTokens: r.PromptEvalCount,
|
PromptTokens: r.PromptEvalCount,
|
||||||
CompletionTokens: r.EvalCount,
|
CompletionTokens: r.EvalCount,
|
||||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
@@ -260,6 +314,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||||
|
if r.Embeddings != nil {
|
||||||
|
var data []Embedding
|
||||||
|
for i, e := range r.Embeddings {
|
||||||
|
data = append(data, Embedding{
|
||||||
|
Object: "embedding",
|
||||||
|
Embedding: e,
|
||||||
|
Index: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return EmbeddingList{
|
||||||
|
Object: "list",
|
||||||
|
Data: data,
|
||||||
|
Model: model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return EmbeddingList{}
|
||||||
|
}
|
||||||
|
|
||||||
func toModel(r api.ShowResponse, m string) Model {
|
func toModel(r api.ShowResponse, m string) Model {
|
||||||
return Model{
|
return Model{
|
||||||
Id: m,
|
Id: m,
|
||||||
@@ -269,10 +344,78 @@ func toModel(r api.ShowResponse, m string) Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||||
var messages []api.Message
|
var messages []api.Message
|
||||||
for _, msg := range r.Messages {
|
for _, msg := range r.Messages {
|
||||||
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
switch content := msg.Content.(type) {
|
||||||
|
case string:
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||||
|
case []any:
|
||||||
|
message := api.Message{Role: msg.Role}
|
||||||
|
for _, c := range content {
|
||||||
|
data, ok := c.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
switch data["type"] {
|
||||||
|
case "text":
|
||||||
|
text, ok := data["text"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
message.Content = text
|
||||||
|
case "image_url":
|
||||||
|
var url string
|
||||||
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||||
|
if url, ok = urlMap["url"].(string); !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if url, ok = data["image_url"].(string); !ok {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
types := []string{"jpeg", "jpg", "png"}
|
||||||
|
valid := false
|
||||||
|
for _, t := range types {
|
||||||
|
prefix := "data:image/" + t + ";base64,"
|
||||||
|
if strings.HasPrefix(url, prefix) {
|
||||||
|
url = strings.TrimPrefix(url, prefix)
|
||||||
|
valid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
return nil, fmt.Errorf("invalid image input")
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := base64.StdEncoding.DecodeString(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
message.Images = append(message.Images, img)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid message format")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, message)
|
||||||
|
default:
|
||||||
|
if msg.ToolCalls == nil {
|
||||||
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
||||||
|
for i, tc := range msg.ToolCalls {
|
||||||
|
toolCalls[i].Function.Name = tc.Function.Name
|
||||||
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid tool call arguments")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
options := make(map[string]interface{})
|
options := make(map[string]interface{})
|
||||||
@@ -323,13 +466,14 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
|||||||
format = "json"
|
format = "json"
|
||||||
}
|
}
|
||||||
|
|
||||||
return api.ChatRequest{
|
return &api.ChatRequest{
|
||||||
Model: r.Model,
|
Model: r.Model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Format: format,
|
Format: format,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
}
|
Tools: r.Tools,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||||
@@ -379,6 +523,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|||||||
Prompt: r.Prompt,
|
Prompt: r.Prompt,
|
||||||
Options: options,
|
Options: options,
|
||||||
Stream: &r.Stream,
|
Stream: &r.Stream,
|
||||||
|
Suffix: r.Suffix,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -407,6 +552,11 @@ type RetrieveWriter struct {
|
|||||||
model string
|
model string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbedWriter struct {
|
||||||
|
BaseWriter
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
||||||
var serr api.StatusError
|
var serr api.StatusError
|
||||||
err := json.Unmarshal(data, &serr)
|
err := json.Unmarshal(data, &serr)
|
||||||
@@ -572,6 +722,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|||||||
return w.writeResponse(data)
|
return w.writeResponse(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||||
|
var embedResponse api.EmbedResponse
|
||||||
|
err := json.Unmarshal(data, &embedResponse)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(code, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
func ListMiddleware() gin.HandlerFunc {
|
func ListMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
w := &ListWriter{
|
w := &ListWriter{
|
||||||
@@ -635,6 +812,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
|||||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req EmbedRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Input == "" {
|
||||||
|
req.Input = []string{""}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Input == nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &EmbedWriter{
|
||||||
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||||
|
model: req.Model,
|
||||||
|
}
|
||||||
|
|
||||||
c.Writer = w
|
c.Writer = w
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
@@ -656,7 +874,13 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
|
||||||
|
chatReq, err := fromChatRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -15,6 +16,10 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const prefix = `data:image/jpeg;base64,`
|
||||||
|
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
|
const imageURL = prefix + image
|
||||||
|
|
||||||
func TestMiddlewareRequests(t *testing.T) {
|
func TestMiddlewareRequests(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
Name string
|
||||||
@@ -80,6 +85,7 @@ func TestMiddlewareRequests(t *testing.T) {
|
|||||||
Prompt: "Hello",
|
Prompt: "Hello",
|
||||||
Temperature: &temp,
|
Temperature: &temp,
|
||||||
Stop: []string{"\n", "stop"},
|
Stop: []string{"\n", "stop"},
|
||||||
|
Suffix: "suffix",
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
bodyBytes, _ := json.Marshal(body)
|
||||||
@@ -110,6 +116,126 @@ func TestMiddlewareRequests(t *testing.T) {
|
|||||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if genReq.Suffix != "suffix" {
|
||||||
|
t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler with image content",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/chat",
|
||||||
|
Handler: ChatMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{
|
||||||
|
{
|
||||||
|
Role: "user", Content: []map[string]any{
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var chatReq api.ChatRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatReq.Messages[0].Role != "user" {
|
||||||
|
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatReq.Messages[0].Content != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
|
if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
|
||||||
|
t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "embed handler single input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Handler: EmbeddingsMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Input: "Hello",
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var embedReq api.EmbedRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedReq.Input != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedReq.Model != "test-model" {
|
||||||
|
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "embed handler batch input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Handler: EmbeddingsMiddleware,
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Input: []string{"Hello", "World"},
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, _ := json.Marshal(body)
|
||||||
|
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *http.Request) {
|
||||||
|
var embedReq api.EmbedRequest
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
input, ok := embedReq.Input.([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected input to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if input[0].(string) != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", input[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if input[1].(string) != "World" {
|
||||||
|
t.Fatalf("expected 'World', got %s", input[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedReq.Model != "test-model" {
|
||||||
|
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@@ -107,9 +107,12 @@ function gatherDependencies() {
|
|||||||
|
|
||||||
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
|
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
|
||||||
# currently works for Win11 + MSVC 2019 + Cuda V11
|
# currently works for Win11 + MSVC 2019 + Cuda V11
|
||||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
|
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
|
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
|
||||||
|
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
|
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
|
||||||
|
@@ -34,11 +34,20 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errCapabilityCompletion = errors.New("completion")
|
var (
|
||||||
|
errCapabilities = errors.New("does not support")
|
||||||
|
errCapabilityCompletion = errors.New("completion")
|
||||||
|
errCapabilityTools = errors.New("tools")
|
||||||
|
errCapabilityInsert = errors.New("insert")
|
||||||
|
)
|
||||||
|
|
||||||
type Capability string
|
type Capability string
|
||||||
|
|
||||||
const CapabilityCompletion = Capability("completion")
|
const (
|
||||||
|
CapabilityCompletion = Capability("completion")
|
||||||
|
CapabilityTools = Capability("tools")
|
||||||
|
CapabilityInsert = Capability("insert")
|
||||||
|
)
|
||||||
|
|
||||||
type registryOptions struct {
|
type registryOptions struct {
|
||||||
Insecure bool
|
Insecure bool
|
||||||
@@ -88,6 +97,15 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
|||||||
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
||||||
errs = append(errs, errCapabilityCompletion)
|
errs = append(errs, errCapabilityCompletion)
|
||||||
}
|
}
|
||||||
|
case CapabilityTools:
|
||||||
|
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||||
|
errs = append(errs, errCapabilityTools)
|
||||||
|
}
|
||||||
|
case CapabilityInsert:
|
||||||
|
vars := m.Template.Vars()
|
||||||
|
if !slices.Contains(vars, "suffix") {
|
||||||
|
errs = append(errs, errCapabilityInsert)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
slog.Error("unknown capability", "capability", cap)
|
slog.Error("unknown capability", "capability", cap)
|
||||||
return fmt.Errorf("unknown capability: %s", cap)
|
return fmt.Errorf("unknown capability: %s", cap)
|
||||||
@@ -95,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := errors.Join(errs...); err != nil {
|
if err := errors.Join(errs...); err != nil {
|
||||||
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
|
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@@ -4,6 +4,7 @@ import (
|
|||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -11,6 +12,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
@@ -289,3 +293,91 @@ func detectContentType(r io.Reader) (string, error) {
|
|||||||
|
|
||||||
return "unknown", nil
|
return "unknown", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
|
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||||
|
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
|
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||||
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
if tmpl == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||||
|
"ToolCalls": {
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "@@name@@",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"@@argument@@": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var kv map[string]any
|
||||||
|
// execute the subtree with placeholders to identify the keys
|
||||||
|
// trim any commands that might exist in the template
|
||||||
|
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the keys that correspond to the name and arguments fields
|
||||||
|
var name, arguments string
|
||||||
|
for k, v := range kv {
|
||||||
|
switch v.(type) {
|
||||||
|
case string:
|
||||||
|
name = k
|
||||||
|
case map[string]any:
|
||||||
|
arguments = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var objs []map[string]any
|
||||||
|
for offset := 0; offset < len(s); {
|
||||||
|
var obj map[string]any
|
||||||
|
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
|
||||||
|
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
||||||
|
// skip over any syntax errors
|
||||||
|
offset += int(syntax.Offset)
|
||||||
|
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
|
||||||
|
// skip over any unmarshalable types
|
||||||
|
offset += int(unmarshalType.Offset)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, false
|
||||||
|
} else {
|
||||||
|
offset += int(decoder.InputOffset())
|
||||||
|
objs = append(objs, obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
for _, kv := range objs {
|
||||||
|
var call api.ToolCall
|
||||||
|
for k, v := range kv {
|
||||||
|
switch k {
|
||||||
|
case name:
|
||||||
|
call.Function.Name = v.(string)
|
||||||
|
case arguments:
|
||||||
|
call.Function.Arguments = v.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls = append(toolCalls, call)
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolCalls, len(toolCalls) > 0
|
||||||
|
}
|
||||||
|
@@ -3,7 +3,9 @@ package server
|
|||||||
import (
|
import (
|
||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -11,7 +13,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createZipFile(t *testing.T, name string) *os.File {
|
func createZipFile(t *testing.T, name string) *os.File {
|
||||||
@@ -110,3 +114,122 @@ func TestExtractFromZipFile(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewBuffer(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithTools(t *testing.T) {
|
||||||
|
p := filepath.Join("testdata", "tools")
|
||||||
|
cases := []struct {
|
||||||
|
model string
|
||||||
|
output string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
|
||||||
|
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||||
|
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"command-r-plus", "Action: ```json" + `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "Toronto, Canada"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
` + "```", true},
|
||||||
|
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"llama3-groq-tool-use", `<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
||||||
|
</tool_call>`, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
var tools []api.Tool
|
||||||
|
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []api.Message
|
||||||
|
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
calls := []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "Toronto, Canada",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.model, func(t *testing.T) {
|
||||||
|
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("template", func(t *testing.T) {
|
||||||
|
var actual bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parse", func(t *testing.T) {
|
||||||
|
m := &Model{Template: tmpl}
|
||||||
|
actual, ok := m.parseToolCalls(tt.output)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.ok {
|
||||||
|
if diff := cmp.Diff(actual, calls); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -4,7 +4,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
@@ -16,29 +15,21 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
|||||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||||
// latest message and 2) system messages
|
// latest message and 2) system messages
|
||||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||||
// pull out any system messages which should always be included in the prompt
|
|
||||||
var system []api.Message
|
var system []api.Message
|
||||||
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
|
|
||||||
if m.Role == "system" {
|
|
||||||
system = append(system, m)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
|
|
||||||
if len(system) == 0 && m.System != "" {
|
|
||||||
// add model system prompt since it wasn't provided
|
|
||||||
system = append(system, api.Message{Role: "system", Content: m.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
// always include the last message
|
// always include the last message
|
||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
// in reverse, find all messages that fit into context window
|
// in reverse, find all messages that fit into context window
|
||||||
for i := n - 1; i >= 0; i-- {
|
for i := n - 1; i >= 0; i-- {
|
||||||
|
system = make([]api.Message, 0)
|
||||||
|
for j := range i {
|
||||||
|
if msgs[j].Role == "system" {
|
||||||
|
system = append(system, msgs[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
|
|
||||||
// truncate any messages that do not fit into the context window
|
// truncate any messages that do not fit into the context window
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -3,21 +3,13 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func tokenize(_ context.Context, s string) (tokens []int, err error) {
|
|
||||||
for range strings.Fields(s) {
|
|
||||||
tokens = append(tokens, len(tokens))
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatPrompt(t *testing.T) {
|
func TestChatPrompt(t *testing.T) {
|
||||||
type expect struct {
|
type expect struct {
|
||||||
prompt string
|
prompt string
|
||||||
@@ -160,6 +152,19 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
{Role: "assistant", Content: "I-I'm a what?"},
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
},
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "out of order system",
|
||||||
|
limit: 2048,
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "system", Content: "You are the Test Who Lived."},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
expect: expect{
|
expect: expect{
|
||||||
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||||
},
|
},
|
||||||
@@ -178,13 +183,13 @@ func TestChatPrompt(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||||
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
|
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.prompt != prompt {
|
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||||
t.Errorf("expected %q, got %q", tt.prompt, prompt)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(images) != len(tt.images) {
|
if len(images) != len(tt.images) {
|
||||||
|
261
server/routes.go
261
server/routes.go
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -102,6 +103,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
checkpointStart := time.Now()
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
@@ -120,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
|
if req.Suffix != "" {
|
||||||
|
caps = append(caps, CapabilityInsert)
|
||||||
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
if errors.Is(err, errCapabilityCompletion) {
|
if errors.Is(err, errCapabilityCompletion) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||||
@@ -129,6 +135,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
if req.Prompt == "" {
|
if req.Prompt == "" {
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@@ -146,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
prompt := req.Prompt
|
prompt := req.Prompt
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
var msgs []api.Message
|
|
||||||
if req.System != "" {
|
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
|
||||||
} else if m.System != "" {
|
|
||||||
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, i := range images {
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
|
||||||
|
|
||||||
tmpl := m.Template
|
tmpl := m.Template
|
||||||
if req.Template != "" {
|
if req.Template != "" {
|
||||||
tmpl, err = template.Parse(req.Template)
|
tmpl, err = template.Parse(req.Template)
|
||||||
@@ -179,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
b.WriteString(s)
|
b.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
|
var values template.Values
|
||||||
|
if req.Suffix != "" {
|
||||||
|
values.Prompt = prompt
|
||||||
|
values.Suffix = req.Suffix
|
||||||
|
} else {
|
||||||
|
var msgs []api.Message
|
||||||
|
if req.System != "" {
|
||||||
|
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
||||||
|
} else if m.System != "" {
|
||||||
|
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, i := range images {
|
||||||
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||||
|
}
|
||||||
|
|
||||||
|
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpl.Execute(&b, values); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -191,26 +205,48 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
|
// TODO (jmorganca): avoid building the response twice both here and below
|
||||||
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
ch <- api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Response: r.Content,
|
Response: cr.Content,
|
||||||
Done: r.Done,
|
Done: cr.Done,
|
||||||
DoneReason: r.DoneReason,
|
DoneReason: cr.DoneReason,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: r.PromptEvalCount,
|
PromptEvalCount: cr.PromptEvalCount,
|
||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||||
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cr.Done {
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
|
if !req.Raw {
|
||||||
|
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||||
|
if err != nil {
|
||||||
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Context = append(req.Context, tokens...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
@@ -246,6 +282,121 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
|
var req api.EmbedRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
return
|
||||||
|
case err != nil:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
truncate := true
|
||||||
|
|
||||||
|
if req.Truncate != nil && !*req.Truncate {
|
||||||
|
truncate = false
|
||||||
|
}
|
||||||
|
|
||||||
|
var input []string
|
||||||
|
|
||||||
|
switch i := req.Input.(type) {
|
||||||
|
case string:
|
||||||
|
if len(i) > 0 {
|
||||||
|
input = append(input, i)
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, v := range i {
|
||||||
|
if _, ok := v.(string); !ok {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
input = append(input, v.(string))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(input) == 0 {
|
||||||
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||||
|
if err != nil {
|
||||||
|
handleScheduleError(c, req.Model, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
kvData, err := getKVData(m.ModelPath, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, s := range input {
|
||||||
|
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||||
|
if len(tokens) > ctxLen {
|
||||||
|
if !truncate {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens = tokens[:ctxLen]
|
||||||
|
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
input[i] = s
|
||||||
|
}
|
||||||
|
embeddings, err := r.Embed(c.Request.Context(), input)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("embedding generation failed", "error", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, e := range embeddings {
|
||||||
|
embeddings[i] = normalize(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := api.EmbedResponse{
|
||||||
|
Model: req.Model,
|
||||||
|
Embeddings: embeddings,
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalize(vec []float32) []float32 {
|
||||||
|
var sum float32
|
||||||
|
for _, v := range vec {
|
||||||
|
sum += v * v
|
||||||
|
}
|
||||||
|
|
||||||
|
norm := float32(0.0)
|
||||||
|
if sum > 0 {
|
||||||
|
norm = float32(1.0 / math.Sqrt(float64(sum)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range vec {
|
||||||
|
vec[i] *= norm
|
||||||
|
}
|
||||||
|
return vec
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
var req api.EmbeddingRequest
|
var req api.EmbeddingRequest
|
||||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
@@ -268,14 +419,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
|
embedding := make([]float64, len(embeddings[0]))
|
||||||
|
|
||||||
|
for i, v := range embeddings[0] {
|
||||||
|
embedding[i] = float64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := api.EmbeddingResponse{
|
||||||
|
Embedding: embedding,
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) PullModelHandler(c *gin.Context) {
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
||||||
@@ -549,13 +710,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
m.System = req.System
|
m.System = req.System
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Template != "" {
|
|
||||||
m.Template, err = template.Parse(req.Template)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs := make([]api.Message, len(m.Messages))
|
msgs := make([]api.Message, len(m.Messages))
|
||||||
for i, msg := range m.Messages {
|
for i, msg := range m.Messages {
|
||||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||||
@@ -901,6 +1055,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|||||||
r.POST("/api/pull", s.PullModelHandler)
|
r.POST("/api/pull", s.PullModelHandler)
|
||||||
r.POST("/api/generate", s.GenerateHandler)
|
r.POST("/api/generate", s.GenerateHandler)
|
||||||
r.POST("/api/chat", s.ChatHandler)
|
r.POST("/api/chat", s.ChatHandler)
|
||||||
|
r.POST("/api/embed", s.EmbedHandler)
|
||||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||||
r.POST("/api/create", s.CreateModelHandler)
|
r.POST("/api/create", s.CreateModelHandler)
|
||||||
r.POST("/api/push", s.PushModelHandler)
|
r.POST("/api/push", s.PushModelHandler)
|
||||||
@@ -914,6 +1069,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|||||||
// Compatibility endpoints
|
// Compatibility endpoints
|
||||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||||
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
||||||
|
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
|
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
|
||||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
|
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
|
||||||
|
|
||||||
@@ -1122,6 +1278,8 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ChatHandler(c *gin.Context) {
|
func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
var req api.ChatRequest
|
var req api.ChatRequest
|
||||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
@@ -1132,6 +1290,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
|
if req.Tools != nil {
|
||||||
|
caps = append(caps, CapabilityTools)
|
||||||
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
if errors.Is(err, errCapabilityCompletion) {
|
if errors.Is(err, errCapabilityCompletion) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||||
@@ -1141,6 +1303,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
if len(req.Messages) == 0 {
|
if len(req.Messages) == 0 {
|
||||||
c.JSON(http.StatusOK, api.ChatResponse{
|
c.JSON(http.StatusOK, api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
@@ -1152,7 +1316,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
|
if req.Messages[0].Role != "system" && m.System != "" {
|
||||||
|
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -1169,7 +1337,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
ch <- api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||||
@@ -1182,19 +1350,26 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Done {
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if req.Stream != nil && !*req.Stream {
|
if req.Stream != nil && !*req.Stream {
|
||||||
var r api.ChatResponse
|
var resp api.ChatResponse
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for rr := range ch {
|
for rr := range ch {
|
||||||
switch t := rr.(type) {
|
switch t := rr.(type) {
|
||||||
case api.ChatResponse:
|
case api.ChatResponse:
|
||||||
sb.WriteString(t.Message.Content)
|
sb.WriteString(t.Message.Content)
|
||||||
r = t
|
resp = t
|
||||||
case gin.H:
|
case gin.H:
|
||||||
msg, ok := t["error"].(string)
|
msg, ok := t["error"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -1209,8 +1384,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Message.Content = sb.String()
|
resp.Message.Content = sb.String()
|
||||||
c.JSON(http.StatusOK, r)
|
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
|
resp.Message.ToolCalls = toolCalls
|
||||||
|
resp.Message.Content = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1219,7 +1402,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
|
|
||||||
func handleScheduleError(c *gin.Context, name string, err error) {
|
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, errRequired):
|
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
case errors.Is(err, context.Canceled):
|
case errors.Is(err, context.Canceled):
|
||||||
c.JSON(499, gin.H{"error": "request canceled"})
|
c.JSON(499, gin.H{"error": "request canceled"})
|
||||||
|
@@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateFromBin(t *testing.T) {
|
func TestCreateFromBin(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateFromModel(t *testing.T) {
|
func TestCreateFromModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateRemovesLayers(t *testing.T) {
|
func TestCreateRemovesLayers(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateUnsetsSystem(t *testing.T) {
|
func TestCreateUnsetsSystem(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateMergeParameters(t *testing.T) {
|
func TestCreateMergeParameters(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateReplacesMessages(t *testing.T) {
|
func TestCreateReplacesMessages(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateTemplateSystem(t *testing.T) {
|
func TestCreateTemplateSystem(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateLicenses(t *testing.T) {
|
func TestCreateLicenses(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateDetectTemplate(t *testing.T) {
|
func TestCreateDetectTemplate(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -546,8 +564,8 @@ func TestCreateDetectTemplate(t *testing.T) {
|
|||||||
|
|
||||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||||
filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
|
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
|
||||||
filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
|
filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@@ -8,12 +8,15 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDelete(t *testing.T) {
|
func TestDelete(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
@@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteDuplicateLayers(t *testing.T) {
|
func TestDeleteDuplicateLayers(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
p := t.TempDir()
|
p := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", p)
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
var s Server
|
var s Server
|
||||||
|
712
server/routes_generate_test.go
Normal file
712
server/routes_generate_test.go
Normal file
@@ -0,0 +1,712 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockRunner struct {
|
||||||
|
llm.LlamaServer
|
||||||
|
|
||||||
|
// CompletionRequest is only valid until the next call to Completion
|
||||||
|
llm.CompletionRequest
|
||||||
|
llm.CompletionResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||||
|
m.CompletionRequest = r
|
||||||
|
fn(m.CompletionResponse)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||||
|
for range strings.Fields(s) {
|
||||||
|
tokens = append(tokens, len(tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||||
|
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
|
return mock, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateChat(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
// add 10ms delay to simulate loading
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
|
TEMPLATE """
|
||||||
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||||
|
`, createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("missing body", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, nil)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities chat", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(0),
|
||||||
|
}, []llm.Tensor{})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual api.ChatResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != "test" {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done true, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "load" {
|
||||||
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var actual api.ChatResponse
|
||||||
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != model {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "stop" {
|
||||||
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(actual.Message, api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: content,
|
||||||
|
}); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalCount == 0 {
|
||||||
|
t.Errorf("expected prompt eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalDuration == 0 {
|
||||||
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalCount == 0 {
|
||||||
|
t.Errorf("expected eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalDuration == 0 {
|
||||||
|
t.Errorf("expected eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.LoadDuration == 0 {
|
||||||
|
t.Errorf("expected load duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.TotalDuration == 0 {
|
||||||
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Hi!"
|
||||||
|
t.Run("messages", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("messages with model system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||||
|
t.Run("messages with system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You can perform magic tricks."},
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("messages with interleaved system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
{Role: "assistant", Content: "I can help you with that."},
|
||||||
|
{Role: "system", Content: "You can perform magic tricks."},
|
||||||
|
{Role: "user", Content: "Help me write tests."},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
|
TEMPLATE """
|
||||||
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||||
|
`, createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("missing body", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, nil)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities generate", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(0),
|
||||||
|
}, []llm.Tensor{})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual api.GenerateResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != "test" {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done true, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "load" {
|
||||||
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var actual api.GenerateResponse
|
||||||
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != model {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "stop" {
|
||||||
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Response != content {
|
||||||
|
t.Errorf("expected response %s, got %s", content, actual.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Context == nil {
|
||||||
|
t.Errorf("expected context not nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalCount == 0 {
|
||||||
|
t.Errorf("expected prompt eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalDuration == 0 {
|
||||||
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalCount == 0 {
|
||||||
|
t.Errorf("expected eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalDuration == 0 {
|
||||||
|
t.Errorf("expected eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.LoadDuration == 0 {
|
||||||
|
t.Errorf("expected load duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.TotalDuration == 0 {
|
||||||
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Hi!"
|
||||||
|
t.Run("prompt", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with model system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||||
|
t.Run("prompt with system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
System: "You can perform magic tricks.",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt with template", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Help me write tests.",
|
||||||
|
System: "You can perform magic tricks.",
|
||||||
|
Template: `{{- if .System }}{{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Modelfile: `FROM test
|
||||||
|
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}"""`,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt without suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Help me write tests.",
|
||||||
|
Raw: true,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@@ -7,11 +7,14 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestList(t *testing.T) {
|
func TestList(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
envconfig.LoadConfig()
|
envconfig.LoadConfig()
|
||||||
|
|
||||||
|
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -272,6 +273,77 @@ func Test_Routes(t *testing.T) {
|
|||||||
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Embed Handler Empty Input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
embedReq := api.EmbedRequest{
|
||||||
|
Model: "t-bone",
|
||||||
|
Input: "",
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(embedReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType != "application/json; charset=utf-8" {
|
||||||
|
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var embedResp api.EmbedResponse
|
||||||
|
err = json.Unmarshal(body, &embedResp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedResp.Model != "t-bone" {
|
||||||
|
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if embedResp.Embeddings == nil {
|
||||||
|
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(embedResp.Embeddings) != 0 {
|
||||||
|
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Embed Handler Invalid Input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
embedReq := api.EmbedRequest{
|
||||||
|
Model: "t-bone",
|
||||||
|
Input: 2,
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(embedReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType != "application/json; charset=utf-8" {
|
||||||
|
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
||||||
|
}
|
||||||
|
_, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
@@ -420,3 +492,38 @@ func TestShow(t *testing.T) {
|
|||||||
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
|
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalize(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input []float32
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{input: []float32{1}},
|
||||||
|
{input: []float32{0, 1, 2, 3}},
|
||||||
|
{input: []float32{0.1, 0.2, 0.3}},
|
||||||
|
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||||
|
{input: []float32{0, 0, 0}},
|
||||||
|
}
|
||||||
|
|
||||||
|
isNormalized := func(vec []float32) (res bool) {
|
||||||
|
sum := 0.0
|
||||||
|
for _, v := range vec {
|
||||||
|
sum += float64(v * v)
|
||||||
|
}
|
||||||
|
if math.Abs(sum-1) > 1e-6 {
|
||||||
|
return sum == 0
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
normalized := normalize(tc.input)
|
||||||
|
if !isNormalized(normalized) {
|
||||||
|
t.Errorf("Vector %v is not normalized", tc.input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -135,11 +135,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
cpus := s.getCpuFn()
|
|
||||||
var systemMem gpu.GpuInfo
|
|
||||||
if len(cpus) > 0 {
|
|
||||||
systemMem = cpus[0]
|
|
||||||
}
|
|
||||||
var runnerToExpire *runnerRef
|
var runnerToExpire *runnerRef
|
||||||
s.loadedMu.Lock()
|
s.loadedMu.Lock()
|
||||||
runner := s.loaded[pending.model.ModelPath]
|
runner := s.loaded[pending.model.ModelPath]
|
||||||
@@ -193,38 +188,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
|
|
||||||
maxSize := systemMem.FreeMemory
|
|
||||||
|
|
||||||
// Add available GPU memory to the total pool
|
|
||||||
// macOS hardware has unified memory so don't double count
|
|
||||||
if runtime.GOOS != "darwin" {
|
|
||||||
for _, gpu := range gpus {
|
|
||||||
if gpu.Library == "cpu" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if loadedCount == 0 {
|
|
||||||
// If no other models are loaded, set the limit based on what's available
|
|
||||||
maxSize += gpu.FreeMemory
|
|
||||||
} else {
|
|
||||||
// Other models could be unloaded, favor total memory for limit
|
|
||||||
maxSize += gpu.TotalMemory
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Block attempting to load a model larger than system memory + GPU memory
|
|
||||||
if estimate.TotalSize > maxSize {
|
|
||||||
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
|
|
||||||
|
|
||||||
// Linux will crash if over-allocating memory - return an error to the user.
|
|
||||||
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
||||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||||
// simplifying assumption of defaultParallel when in CPU mode
|
// simplifying assumption of defaultParallel when in CPU mode
|
||||||
|
@@ -642,8 +642,8 @@ type mockLlm struct {
|
|||||||
pingResp error
|
pingResp error
|
||||||
waitResp error
|
waitResp error
|
||||||
completionResp error
|
completionResp error
|
||||||
embeddingResp []float64
|
embedResp [][]float32
|
||||||
embeddingRespErr error
|
embedRespErr error
|
||||||
tokenizeResp []int
|
tokenizeResp []int
|
||||||
tokenizeRespErr error
|
tokenizeRespErr error
|
||||||
detokenizeResp string
|
detokenizeResp string
|
||||||
@@ -660,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
|
|||||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
return s.completionResp
|
return s.completionResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||||
return s.embeddingResp, s.embeddingRespErr
|
return s.embedResp, s.embedRespErr
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
return s.tokenizeResp, s.tokenizeRespErr
|
return s.tokenizeResp, s.tokenizeRespErr
|
||||||
|
67
server/testdata/tools/command-r-plus.gotmpl
vendored
Normal file
67
server/testdata/tools/command-r-plus.gotmpl
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
|
||||||
|
{{- if .Tools }}# Safety Preamble
|
||||||
|
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||||
|
|
||||||
|
# System Preamble
|
||||||
|
## Basic Rules
|
||||||
|
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||||
|
|
||||||
|
{{ if .System }}# User Preamble
|
||||||
|
{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
Here is a list of tools that you have available to you:
|
||||||
|
{{- range .Tools }}
|
||||||
|
|
||||||
|
```python
|
||||||
|
def {{ .Function.Name }}(
|
||||||
|
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
|
||||||
|
"""{{ .Function.Description }}
|
||||||
|
|
||||||
|
{{- if .Function.Parameters.Properties }}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
{{- range $name, $property := .Function.Parameters.Properties }}
|
||||||
|
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
{{- end }}
|
||||||
|
{{- else if .System }}{{ .System }}
|
||||||
|
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||||
|
{{- end }}
|
||||||
|
{{- range .Messages }}
|
||||||
|
{{- if eq .Role "system" }}
|
||||||
|
{{- continue }}
|
||||||
|
{{- end }}<|START_OF_TURN_TOKEN|>
|
||||||
|
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
|
||||||
|
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
|
||||||
|
{{- if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}
|
||||||
|
Action: ```json
|
||||||
|
[
|
||||||
|
{{- range .ToolCalls }}
|
||||||
|
{
|
||||||
|
"tool_name": "{{ .Function.Name }}",
|
||||||
|
"parameters": {{ .Function.Arguments }}
|
||||||
|
}
|
||||||
|
{{- end }}
|
||||||
|
]```
|
||||||
|
{{ continue }}
|
||||||
|
{{ end }}
|
||||||
|
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
|
||||||
|
{{ .Content }}</results>
|
||||||
|
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": title of the tool in the specification,
|
||||||
|
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||||
|
}
|
||||||
|
]```
|
||||||
|
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
39
server/testdata/tools/command-r-plus.out
vendored
Normal file
39
server/testdata/tools/command-r-plus.out
vendored
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
|
||||||
|
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||||
|
|
||||||
|
# System Preamble
|
||||||
|
## Basic Rules
|
||||||
|
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||||
|
|
||||||
|
# User Preamble
|
||||||
|
You are a knowledgable assistant. You can answer questions and perform tasks.
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
Here is a list of tools that you have available to you:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_current_weather(format: string, location: string, ) -> List[Dict]:
|
||||||
|
"""Get the current weather
|
||||||
|
|
||||||
|
Args:
|
||||||
|
format (string): The temperature unit to use. Infer this from the users location.
|
||||||
|
location (string): The city and state, e.g. San Francisco, CA
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||||
|
Action: ```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {"format":"celsius","location":"Paris, France"}
|
||||||
|
}
|
||||||
|
]```
|
||||||
|
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
|
||||||
|
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": title of the tool in the specification,
|
||||||
|
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||||
|
}
|
||||||
|
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
31
server/testdata/tools/firefunction.gotmpl
vendored
Normal file
31
server/testdata/tools/firefunction.gotmpl
vendored
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
{{- if .System }}
|
||||||
|
{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||||
|
|
||||||
|
Use the following rule to decide when to call a function:
|
||||||
|
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||||
|
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||||
|
|
||||||
|
If you decide to call functions:
|
||||||
|
* prefix function calls with functools marker (no closing marker required)
|
||||||
|
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||||
|
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||||
|
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||||
|
* make sure you pick the right functions that match the user intent
|
||||||
|
|
||||||
|
Available functions as JSON spec:
|
||||||
|
{{- if .Tools }}
|
||||||
|
{{ .Tools }}
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- end }}
|
||||||
|
{{- range .Messages }}<|start_header_id|>
|
||||||
|
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
|
||||||
|
{{- end }}<|end_header_id|>
|
||||||
|
{{- if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }} functools[
|
||||||
|
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||||
|
{{- end }}]
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
17
server/testdata/tools/firefunction.out
vendored
Normal file
17
server/testdata/tools/firefunction.out
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
You are a knowledgable assistant. You can answer questions and perform tasks.
|
||||||
|
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||||
|
|
||||||
|
Use the following rule to decide when to call a function:
|
||||||
|
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||||
|
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||||
|
|
||||||
|
If you decide to call functions:
|
||||||
|
* prefix function calls with functools marker (no closing marker required)
|
||||||
|
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||||
|
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||||
|
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||||
|
* make sure you pick the right functions that match the user intent
|
||||||
|
|
||||||
|
Available functions as JSON spec:
|
||||||
|
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
43
server/testdata/tools/llama3-groq-tool-use.gotmpl
vendored
Normal file
43
server/testdata/tools/llama3-groq-tool-use.gotmpl
vendored
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
{{- if .Messages }}
|
||||||
|
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .System }}
|
||||||
|
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>,"arguments": <args-dict>}
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
Here are the available tools:
|
||||||
|
<tools>
|
||||||
|
{{- range .Tools }} {{ .Function }}
|
||||||
|
{{- end }} </tools>
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- range .Messages }}
|
||||||
|
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||||
|
|
||||||
|
{{ if eq .Role "user" }}{{ .Content }}
|
||||||
|
{{- else if eq .Role "assistant" }}
|
||||||
|
{{- if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}<tool_call>
|
||||||
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{- end }}
|
||||||
|
</tool_call>
|
||||||
|
{{- end }}
|
||||||
|
{{- else if eq .Role "tool" }}<tool_response>
|
||||||
|
{{ .Content }}
|
||||||
|
</tool_response>
|
||||||
|
{{- end }}<|eot_id|>
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ else }}
|
||||||
|
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ end }}{{ .Response }}
|
||||||
|
{{- if .Response }}<|eot_id|>
|
||||||
|
{{- end }}
|
24
server/testdata/tools/llama3-groq-tool-use.out
vendored
Normal file
24
server/testdata/tools/llama3-groq-tool-use.out
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>,"arguments": <args-dict>}
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
Here are the available tools:
|
||||||
|
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||||
|
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
|
||||||
|
|
||||||
|
<tool_response>
|
||||||
|
22
|
||||||
|
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
39
server/testdata/tools/messages.json
vendored
Normal file
39
server/testdata/tools/messages.json
vendored
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What's the weather like today in Paris?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"arguments": {
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||||
|
"content": "22"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "The current temperature in Paris, France is 22 degrees Celsius."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What's the weather like today in San Francisco and Toronto?"
|
||||||
|
}
|
||||||
|
]
|
15
server/testdata/tools/mistral.gotmpl
vendored
Normal file
15
server/testdata/tools/mistral.gotmpl
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
{{- range $index, $_ := .Messages }}
|
||||||
|
{{- if eq .Role "user" }}
|
||||||
|
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||||
|
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||||
|
|
||||||
|
{{ end }}{{ .Content }}[/INST]
|
||||||
|
{{- else if eq .Role "assistant" }}
|
||||||
|
{{- if .Content }} {{ .Content }}</s>
|
||||||
|
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||||
|
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{- end }}]</s>
|
||||||
|
{{- end }}
|
||||||
|
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
3
server/testdata/tools/mistral.out
vendored
Normal file
3
server/testdata/tools/mistral.out
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
|
||||||
|
|
||||||
|
What's the weather like today in San Francisco and Toronto?[/INST]
|
30
server/testdata/tools/tools.json
vendored
Normal file
30
server/testdata/tools/tools.json
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA"
|
||||||
|
},
|
||||||
|
"format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"celsius",
|
||||||
|
"fahrenheit"
|
||||||
|
],
|
||||||
|
"description": "The temperature unit to use. Infer this from the users location."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"location",
|
||||||
|
"format"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
@@ -1,8 +1 @@
|
|||||||
{{- if .Messages }}
|
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
|
||||||
{{- if .System }}<start_system>{{ .System }}<end_message>
|
|
||||||
{{- end }}
|
|
||||||
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
|
|
||||||
{{- end }}<start_assistant>
|
|
||||||
{{- else }}
|
|
||||||
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
|
|
||||||
{{- end }}
|
|
@@ -1,14 +1,3 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}{{ .System }}
|
|
||||||
{{- end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}### Instruction:
|
|
||||||
{{- else if eq .Role "assistant" }}### Response:
|
|
||||||
{{- end }}
|
|
||||||
{{ .Content }}
|
|
||||||
|
|
||||||
{{ end }}### Response:
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}{{ .System }}
|
{{ if .System }}{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ if .Prompt }}### Instruction:
|
{{ end }}{{ if .Prompt }}### Instruction:
|
||||||
@@ -16,4 +5,4 @@
|
|||||||
|
|
||||||
{{ end }}### Response:
|
{{ end }}### Response:
|
||||||
{{ .Response }}
|
{{ .Response }}
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,15 +1,6 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}<|im_start|>system
|
|
||||||
{{ .System }}<|im_end|>
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}<|im_start|>{{ .Role }}
|
|
||||||
{{ .Content }}<|im_end|>
|
|
||||||
{{ end }}<|im_start|>assistant
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}<|im_start|>system
|
{{ if .System }}<|im_start|>system
|
||||||
{{ .System }}<|im_end|>
|
{{ .System }}<|im_end|>
|
||||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||||
{{ .Prompt }}<|im_end|>
|
{{ .Prompt }}<|im_end|>
|
||||||
{{ end }}<|im_start|>assistant
|
{{ end }}<|im_start|>assistant
|
||||||
{{ .Response }}<|im_end|>
|
{{ .Response }}<|im_end|>
|
||||||
{{- end }}
|
|
@@ -1,17 +1,6 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}System: {{ .System }}
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}User:
|
|
||||||
{{- else if eq .Role "assistant" }}Assistant:
|
|
||||||
{{- end }} {{ .Content }}
|
|
||||||
|
|
||||||
{{ end }}Assistant:
|
|
||||||
{{- else }}
|
|
||||||
{{ if .System }}System: {{ .System }}
|
{{ if .System }}System: {{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
||||||
|
|
||||||
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
|
{{ end }}Assistant: {{ .Response }}
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,19 +1,10 @@
|
|||||||
{{- if .Messages }}
|
{{ if .System }}Source: system
|
||||||
{{- if .System }}Source: system
|
|
||||||
|
|
||||||
{{ .System }} <step> {{ end }}
|
{{ .System }} <step> {{ end }}Source: user
|
||||||
{{- range .Messages }}Source: {{ .Role }}
|
|
||||||
|
|
||||||
{{ .Content }} <step> {{ end }}Source: assistant
|
|
||||||
Destination: user
|
|
||||||
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }} Source: system
|
|
||||||
|
|
||||||
{{ .System }} <step>{{ end }} Source: user
|
|
||||||
|
|
||||||
{{ .Prompt }} <step> Source: assistant
|
{{ .Prompt }} <step> Source: assistant
|
||||||
|
{{- if not .Response }}
|
||||||
Destination: user
|
Destination: user
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
{{ .Response }}<step>
|
{{ .Response }} <step>
|
||||||
{{- end }}
|
|
@@ -1,13 +1,5 @@
|
|||||||
{{- if .Messages }}
|
{{ if .System }}System: {{ .System }}
|
||||||
{{- if .System }}System: {{ .System }}
|
{{ end }}{{ if .Prompt }}User:
|
||||||
{{ end }}
|
{{ .Prompt }}
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}User:
|
|
||||||
{{ else if eq .Role "assistant" }}Falcon:
|
|
||||||
{{ end }}{{ .Content }}
|
|
||||||
{{ end }}Falcon:
|
{{ end }}Falcon:
|
||||||
{{ else }}
|
{{ .Response }}
|
||||||
{{ if .System }}{{ .System }}
|
|
||||||
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
|
|
||||||
{{ end }}Assistant: {{ .Response }}
|
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,16 +1,5 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- range $index, $_ := .Messages }}<start_of_turn>
|
|
||||||
{{- if eq .Role "user" }}user
|
|
||||||
{{- if and $.System (eq $index 0) }}
|
|
||||||
{{ $.System }}
|
|
||||||
{{- end }}
|
|
||||||
{{- else if eq .Role "assistant" }}model
|
|
||||||
{{- end }}
|
|
||||||
{{ .Content }}<end_of_turn>
|
|
||||||
{{ end }}<start_of_turn>model
|
|
||||||
{{ else }}
|
|
||||||
<start_of_turn>user
|
<start_of_turn>user
|
||||||
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
|
{{ if .System }}{{ .System }}
|
||||||
|
{{ end }}{{ .Prompt }}<end_of_turn>
|
||||||
<start_of_turn>model
|
<start_of_turn>model
|
||||||
{{ .Response }}<end_of_turn>
|
{{ .Response }}<end_of_turn>
|
||||||
{{- end }}
|
|
@@ -1,18 +1,4 @@
|
|||||||
{{- if .Messages }}
|
{{ if .System }}System:
|
||||||
{{- if .System }}System:
|
|
||||||
{{ .System }}
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}Question:
|
|
||||||
{{- else if eq .Role "assistant" }}Answer:
|
|
||||||
{{- end }}
|
|
||||||
{{ .Content }}
|
|
||||||
|
|
||||||
{{ end }}Answer:
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}
|
|
||||||
System:
|
|
||||||
{{ .System }}
|
{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ if .Prompt }}Question:
|
{{ end }}{{ if .Prompt }}Question:
|
||||||
@@ -20,4 +6,4 @@ System:
|
|||||||
|
|
||||||
{{ end }}Answer:
|
{{ end }}Answer:
|
||||||
{{ .Response }}
|
{{ .Response }}
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,16 +1,6 @@
|
|||||||
{{- if .Messages }}
|
[INST] <<SYS>>
|
||||||
{{- range $index, $_ := .Messages }}
|
{{- if .System }}
|
||||||
{{- if eq .Role "user" }}[INST] {{ if eq $index 0 }}<<SYS>>
|
{{ .System }}
|
||||||
{{- if $.System }}
|
|
||||||
{{ $.System }}
|
|
||||||
{{ end }}<</SYS>>
|
{{ end }}<</SYS>>
|
||||||
|
|
||||||
{{ end }}{{ .Content }}
|
{{ .Prompt }} [/INST] {{ .Response }}</s><s>
|
||||||
{{- else }} [/INST] {{ .Content }}</s><s>
|
|
||||||
{{- end }}
|
|
||||||
{{- end }} [/INST]
|
|
||||||
{{- else }}
|
|
||||||
[INST] <<SYS>>{{ .System }}<</SYS>>
|
|
||||||
|
|
||||||
{{ .Prompt }} [/INST] {{ .Response }}
|
|
||||||
{{- end }}
|
|
@@ -1,19 +1,7 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
{{ .System }}<|eot_id|>
|
|
||||||
{{- end }}
|
|
||||||
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
|
||||||
|
|
||||||
{{ .Content }}<|eot_id|>
|
|
||||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
{{ .Response }}<|eot_id|>
|
{{ .Response }}<|eot_id|>
|
||||||
{{- end }}
|
|
@@ -1,15 +1,3 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}{{ .System }}
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}@@ Instruction
|
|
||||||
{{- else if eq .Role "assistant" }}@@ Response
|
|
||||||
{{- end }}
|
|
||||||
{{ .Content }}
|
|
||||||
|
|
||||||
{{ end }}@@ Response
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}{{ .System }}
|
{{ if .System }}{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ if .Prompt }}@@ Instruction
|
{{ end }}{{ if .Prompt }}@@ Instruction
|
||||||
@@ -17,4 +5,4 @@
|
|||||||
|
|
||||||
{{ end }}@@ Response
|
{{ end }}@@ Response
|
||||||
{{ .Response }}
|
{{ .Response }}
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,9 +1,3 @@
|
|||||||
{{- if .Messages }}
|
[INST] {{ if .System }}{{ .System }}
|
||||||
{{- range $index, $_ := .Messages }}
|
|
||||||
{{- if eq .Role "user" }}[INST] {{ if and $.System (eq (len (slice $.Messages $index)) 1) }}{{ $.System }}
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
|
||||||
{{ end }}{{ .Content }}
|
|
||||||
{{- else if eq .Role "assistant" }}[/INST] {{ .Content }}</s>
|
|
||||||
{{- end }}
|
|
||||||
{{- end }}[/INST]
|
|
||||||
{{- else }}[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST] {{ .Response }}
|
|
||||||
{{- end }}
|
|
@@ -1,11 +1 @@
|
|||||||
{{- if .Messages }}
|
{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
|
||||||
{{- if .System }}GPT Correct System: {{ .System }}<|end_of_turn|>
|
|
||||||
{{- end }}
|
|
||||||
{{- range .Messages }}GPT Correct
|
|
||||||
{{- if eq .Role "user" }} User:
|
|
||||||
{{- else if eq .Role "assistant" }} Assistant:
|
|
||||||
{{- end }} {{ .Content }}<|end_of_turn|>
|
|
||||||
{{- end }}GPT Correct Assistant:
|
|
||||||
{{- else }}
|
|
||||||
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
|
|
||||||
{{- end }}
|
|
@@ -1,15 +1,6 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}<|system|>
|
|
||||||
{{ .System }}<|end|>
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}<|{{ .Role }}|>
|
|
||||||
{{ .Content }}<|end|>
|
|
||||||
{{ end }}<|assistant|>
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}<|system|>
|
{{ if .System }}<|system|>
|
||||||
{{ .System }}<|end|>
|
{{ .System }}<|end|>
|
||||||
{{ end }}{{ if .Prompt }}<|user|>
|
{{ end }}{{ if .Prompt }}<|user|>
|
||||||
{{ .Prompt }}<|end|>
|
{{ .Prompt }}<|end|>
|
||||||
{{ end }}<|assistant|>
|
{{ end }}<|assistant|>
|
||||||
{{ .Response }}<|end|>
|
{{ .Response }}<|end|>
|
||||||
{{- end }}
|
|
@@ -1,16 +1,3 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}### System:
|
|
||||||
{{ .System }}
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}### User:
|
|
||||||
{{ .Content }}
|
|
||||||
{{ else if eq .Role "assistant" }}### Assistant:
|
|
||||||
{{ .Content }}</s>
|
|
||||||
{{ end }}
|
|
||||||
{{ end }}### Assistant:
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}### System:
|
{{ if .System }}### System:
|
||||||
{{ .System }}
|
{{ .System }}
|
||||||
|
|
||||||
@@ -18,5 +5,5 @@
|
|||||||
{{ .Prompt }}
|
{{ .Prompt }}
|
||||||
|
|
||||||
{{ end }}### Assistant:
|
{{ end }}### Assistant:
|
||||||
{{ .Response }}
|
{{ .Response }}</s>
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,24 +1,8 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}{{ .System }}
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}### Instruction
|
|
||||||
{{ .Content }}
|
|
||||||
|
|
||||||
{{ else if eq .Role "assistant" }}### Response
|
|
||||||
{{ .Content }}<|endoftext|>
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- end }}### Response
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}{{ .System }}
|
{{ if .System }}{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ if .Prompt }}### Instruction
|
{{ end }}{{ if .Prompt }}### Instruction
|
||||||
{{ .Prompt }}
|
{{ .Prompt }}
|
||||||
|
|
||||||
|
|
||||||
{{ end }}### Response
|
{{ end }}### Response
|
||||||
{{ .Response }}<|endoftext|>
|
{{ .Response }}<|endoftext|>
|
||||||
|
|
||||||
{{- end }}
|
|
@@ -102,8 +102,15 @@ var response = parse.ActionNode{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var funcs = template.FuncMap{
|
||||||
|
"json": func(v any) string {
|
||||||
|
b, _ := json.Marshal(v)
|
||||||
|
return string(b)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func Parse(s string) (*Template, error) {
|
func Parse(s string) (*Template, error) {
|
||||||
tmpl := template.New("").Option("missingkey=zero")
|
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
||||||
|
|
||||||
tmpl, err := tmpl.Parse(s)
|
tmpl, err := tmpl.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -127,7 +134,7 @@ func (t *Template) Vars() []string {
|
|||||||
var vars []string
|
var vars []string
|
||||||
for _, tt := range t.Templates() {
|
for _, tt := range t.Templates() {
|
||||||
for _, n := range tt.Root.Nodes {
|
for _, n := range tt.Root.Nodes {
|
||||||
vars = append(vars, parseNode(n)...)
|
vars = append(vars, Identifiers(n)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,54 +150,130 @@ func (t *Template) Vars() []string {
|
|||||||
|
|
||||||
type Values struct {
|
type Values struct {
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
|
api.Tools
|
||||||
|
Prompt string
|
||||||
|
Suffix string
|
||||||
|
|
||||||
|
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||||
|
forceLegacy bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||||
|
var walk func(parse.Node) parse.Node
|
||||||
|
walk = func(n parse.Node) parse.Node {
|
||||||
|
if fn(n) {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t := n.(type) {
|
||||||
|
case *parse.ListNode:
|
||||||
|
for _, c := range t.Nodes {
|
||||||
|
if n := walk(c); n != nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *parse.BranchNode:
|
||||||
|
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
|
||||||
|
if n != nil {
|
||||||
|
if n := walk(n); n != nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *parse.IfNode:
|
||||||
|
return walk(&t.BranchNode)
|
||||||
|
case *parse.WithNode:
|
||||||
|
return walk(&t.BranchNode)
|
||||||
|
case *parse.RangeNode:
|
||||||
|
return walk(&t.BranchNode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := walk(t.Tree.Root); n != nil {
|
||||||
|
return (&template.Template{
|
||||||
|
Tree: &parse.Tree{
|
||||||
|
Root: &parse.ListNode{
|
||||||
|
Nodes: []parse.Node{n},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}).Funcs(funcs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system, collated := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
if slices.Contains(t.Vars(), "messages") {
|
if v.Prompt != "" && v.Suffix != "" {
|
||||||
|
return t.Template.Execute(w, map[string]any{
|
||||||
|
"Prompt": v.Prompt,
|
||||||
|
"Suffix": v.Suffix,
|
||||||
|
"Response": "",
|
||||||
|
})
|
||||||
|
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": collated,
|
"Messages": messages,
|
||||||
|
"Tools": v.Tools,
|
||||||
|
"Response": "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
system = ""
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
var prompt, response string
|
var prompt, response string
|
||||||
for i, m := range collated {
|
for _, m := range messages {
|
||||||
if m.Role == "user" {
|
execute := func() error {
|
||||||
prompt = m.Content
|
|
||||||
} else {
|
|
||||||
response = m.Content
|
|
||||||
}
|
|
||||||
|
|
||||||
if i != len(collated)-1 && prompt != "" && response != "" {
|
|
||||||
if err := t.Template.Execute(&b, map[string]any{
|
if err := t.Template.Execute(&b, map[string]any{
|
||||||
"System": "",
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
"Response": response,
|
"Response": response,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
system = ""
|
||||||
prompt = ""
|
prompt = ""
|
||||||
response = ""
|
response = ""
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case "system":
|
||||||
|
if prompt != "" || response != "" {
|
||||||
|
if err := execute(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
system = m.Content
|
||||||
|
case "user":
|
||||||
|
if response != "" {
|
||||||
|
if err := execute(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prompt = m.Content
|
||||||
|
case "assistant":
|
||||||
|
response = m.Content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var cut bool
|
var cut bool
|
||||||
tree := t.Template.Copy()
|
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||||
// for the last message, cut everything after "{{ .Response }}"
|
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||||
tree.Root.Nodes = slices.DeleteFunc(tree.Root.Nodes, func(n parse.Node) bool {
|
|
||||||
if slices.Contains(parseNode(n), "Response") {
|
|
||||||
cut = true
|
cut = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return cut
|
return cut
|
||||||
})
|
})
|
||||||
|
|
||||||
if err := template.Must(template.New("").AddParseTree("", tree)).Execute(&b, map[string]any{
|
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||||
"System": system,
|
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||||
"Prompt": prompt,
|
"System": system,
|
||||||
|
"Prompt": prompt,
|
||||||
|
"Response": "",
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -199,25 +282,16 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
type messages []*api.Message
|
|
||||||
|
|
||||||
// collate messages based on role. consecutive messages of the same role are merged
|
// collate messages based on role. consecutive messages of the same role are merged
|
||||||
// into a single message. collate also pulls out and merges messages with Role == "system"
|
// into a single message. collate also collects and returns all system messages.
|
||||||
// which are templated separately. As a side effect, it mangles message content adding image
|
// collate mutates message content adding image tags ([img-%d]) as needed
|
||||||
// tags ([img-%d]) as needed
|
func collate(msgs []api.Message) (string, []*api.Message) {
|
||||||
func collate(msgs []api.Message) (system string, collated messages) {
|
|
||||||
var n int
|
var n int
|
||||||
|
|
||||||
|
var system []string
|
||||||
|
var collated []*api.Message
|
||||||
for i := range msgs {
|
for i := range msgs {
|
||||||
msg := msgs[i]
|
msg := msgs[i]
|
||||||
if msg.Role == "system" {
|
|
||||||
if system != "" {
|
|
||||||
system += "\n\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
system += msg.Content
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for range msg.Images {
|
for range msg.Images {
|
||||||
imageTag := fmt.Sprintf("[img-%d]", n)
|
imageTag := fmt.Sprintf("[img-%d]", n)
|
||||||
if !strings.Contains(msg.Content, "[img]") {
|
if !strings.Contains(msg.Content, "[img]") {
|
||||||
@@ -228,6 +302,10 @@ func collate(msgs []api.Message) (system string, collated messages) {
|
|||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if msg.Role == "system" {
|
||||||
|
system = append(system, msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
if len(collated) > 0 && collated[len(collated)-1].Role == msg.Role {
|
||||||
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
collated[len(collated)-1].Content += "\n\n" + msg.Content
|
||||||
} else {
|
} else {
|
||||||
@@ -235,54 +313,119 @@ func collate(msgs []api.Message) (system string, collated messages) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return strings.Join(system, "\n\n"), collated
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNode(n parse.Node) []string {
|
// Identifiers walks the node tree returning any identifiers it finds along the way
|
||||||
|
func Identifiers(n parse.Node) []string {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
|
case *parse.ListNode:
|
||||||
|
var names []string
|
||||||
|
for _, n := range n.Nodes {
|
||||||
|
names = append(names, Identifiers(n)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
case *parse.TemplateNode:
|
||||||
|
return Identifiers(n.Pipe)
|
||||||
case *parse.ActionNode:
|
case *parse.ActionNode:
|
||||||
return parseNode(n.Pipe)
|
return Identifiers(n.Pipe)
|
||||||
|
case *parse.BranchNode:
|
||||||
|
names := Identifiers(n.Pipe)
|
||||||
|
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
|
||||||
|
if n != nil {
|
||||||
|
names = append(names, Identifiers(n)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
case *parse.IfNode:
|
case *parse.IfNode:
|
||||||
names := parseNode(n.Pipe)
|
return Identifiers(&n.BranchNode)
|
||||||
names = append(names, parseNode(n.List)...)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
names = append(names, parseNode(n.ElseList)...)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
case *parse.RangeNode:
|
case *parse.RangeNode:
|
||||||
names := parseNode(n.Pipe)
|
return Identifiers(&n.BranchNode)
|
||||||
names = append(names, parseNode(n.List)...)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
names = append(names, parseNode(n.ElseList)...)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
case *parse.WithNode:
|
case *parse.WithNode:
|
||||||
names := parseNode(n.Pipe)
|
return Identifiers(&n.BranchNode)
|
||||||
names = append(names, parseNode(n.List)...)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
names = append(names, parseNode(n.ElseList)...)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
case *parse.PipeNode:
|
case *parse.PipeNode:
|
||||||
var names []string
|
var names []string
|
||||||
for _, c := range n.Cmds {
|
for _, c := range n.Cmds {
|
||||||
for _, a := range c.Args {
|
for _, a := range c.Args {
|
||||||
names = append(names, parseNode(a)...)
|
names = append(names, Identifiers(a)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return names
|
|
||||||
case *parse.ListNode:
|
|
||||||
var names []string
|
|
||||||
for _, n := range n.Nodes {
|
|
||||||
names = append(names, parseNode(n)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return names
|
return names
|
||||||
case *parse.FieldNode:
|
case *parse.FieldNode:
|
||||||
return n.Ident
|
return n.Ident
|
||||||
case *parse.TemplateNode:
|
case *parse.VariableNode:
|
||||||
return parseNode(n.Pipe)
|
return n.Ident
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deleteNode walks the node list and deletes nodes that match the predicate
|
||||||
|
// this is currently to remove the {{ .Response }} node from templates
|
||||||
|
func deleteNode(n parse.Node, fn func(parse.Node) bool) parse.Node {
|
||||||
|
var walk func(n parse.Node) parse.Node
|
||||||
|
walk = func(n parse.Node) parse.Node {
|
||||||
|
if fn(n) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t := n.(type) {
|
||||||
|
case *parse.ListNode:
|
||||||
|
var nodes []parse.Node
|
||||||
|
for _, c := range t.Nodes {
|
||||||
|
if n := walk(c); n != nil {
|
||||||
|
nodes = append(nodes, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Nodes = nodes
|
||||||
|
return t
|
||||||
|
case *parse.IfNode:
|
||||||
|
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
|
||||||
|
case *parse.WithNode:
|
||||||
|
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
|
||||||
|
case *parse.RangeNode:
|
||||||
|
t.BranchNode = *(walk(&t.BranchNode).(*parse.BranchNode))
|
||||||
|
case *parse.BranchNode:
|
||||||
|
t.List = walk(t.List).(*parse.ListNode)
|
||||||
|
if t.ElseList != nil {
|
||||||
|
t.ElseList = walk(t.ElseList).(*parse.ListNode)
|
||||||
|
}
|
||||||
|
case *parse.ActionNode:
|
||||||
|
n := walk(t.Pipe)
|
||||||
|
if n == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Pipe = n.(*parse.PipeNode)
|
||||||
|
case *parse.PipeNode:
|
||||||
|
var commands []*parse.CommandNode
|
||||||
|
for _, c := range t.Cmds {
|
||||||
|
var args []parse.Node
|
||||||
|
for _, a := range c.Args {
|
||||||
|
if n := walk(a); n != nil {
|
||||||
|
args = append(args, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Args = args
|
||||||
|
commands = append(commands, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(commands) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cmds = commands
|
||||||
|
}
|
||||||
|
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
return walk(n)
|
||||||
|
}
|
||||||
|
@@ -105,8 +105,8 @@ func TestTemplate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for n, tt := range cases {
|
for n, tt := range cases {
|
||||||
|
var actual bytes.Buffer
|
||||||
t.Run(n, func(t *testing.T) {
|
t.Run(n, func(t *testing.T) {
|
||||||
var actual bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
|
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -116,7 +116,34 @@ func TestTemplate(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(actual.Bytes(), expect); diff != "" {
|
bts := actual.Bytes()
|
||||||
|
|
||||||
|
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
|
||||||
|
t.Log("removing trailing space from output")
|
||||||
|
bts = bts[:len(bts)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(bts, expect); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("legacy", func(t *testing.T) {
|
||||||
|
t.Skip("legacy outputs are currently default outputs")
|
||||||
|
var legacy bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyBytes := legacy.Bytes()
|
||||||
|
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
|
||||||
|
t.Log("removing trailing space from legacy output")
|
||||||
|
legacyBytes = legacyBytes[:len(legacyBytes)-1]
|
||||||
|
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
|
||||||
|
t.Skip("legacy outputs cannot be compared to messages outputs")
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -135,7 +162,24 @@ func TestParse(t *testing.T) {
|
|||||||
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
|
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
|
||||||
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
|
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
|
||||||
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
|
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
|
||||||
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
|
{`{{- range .Messages }}
|
||||||
|
{{- if eq .Role "system" }}SYSTEM:
|
||||||
|
{{- else if eq .Role "user" }}USER:
|
||||||
|
{{- else if eq .Role "assistant" }}ASSISTANT:
|
||||||
|
{{- end }} {{ .Content }}
|
||||||
|
{{- end }}`, []string{"content", "messages", "role"}},
|
||||||
|
{`{{- if .Messages }}
|
||||||
|
{{- range .Messages }}<|im_start|>{{ .Role }}
|
||||||
|
{{ .Content }}<|im_end|>
|
||||||
|
{{ end }}<|im_start|>assistant
|
||||||
|
{{ else -}}
|
||||||
|
{{ if .System }}<|im_start|>system
|
||||||
|
{{ .System }}<|im_end|>
|
||||||
|
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||||
|
{{ .Prompt }}<|im_end|>
|
||||||
|
{{ end }}<|im_start|>assistant
|
||||||
|
{{ .Response }}<|im_end|>
|
||||||
|
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
@@ -145,9 +189,8 @@ func TestParse(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := tmpl.Vars()
|
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
|
||||||
if !slices.Equal(tt.vars, vars) {
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
t.Errorf("expected %v, got %v", tt.vars, vars)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -167,12 +210,17 @@ func TestExecuteWithMessages(t *testing.T) {
|
|||||||
{
|
{
|
||||||
"mistral",
|
"mistral",
|
||||||
[]template{
|
[]template{
|
||||||
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
{"no response", `[INST] {{ if .System }}{{ .System }}
|
||||||
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
||||||
{"messages", `{{- range $index, $_ := .Messages }}
|
{{ end }}{{ .Prompt }}[/INST] `},
|
||||||
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
||||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
|
||||||
{{- end }}
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
|
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||||||
|
|
||||||
|
{{ end }}
|
||||||
|
{{- range .Messages }}
|
||||||
|
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||||||
{{- end }}`},
|
{{- end }}`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
@@ -187,13 +235,17 @@ func TestExecuteWithMessages(t *testing.T) {
|
|||||||
{
|
{
|
||||||
"mistral system",
|
"mistral system",
|
||||||
[]template{
|
[]template{
|
||||||
{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
|
{"no response", `[INST] {{ if .System }}{{ .System }}
|
||||||
{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
|
||||||
{"messages", `
|
{{ end }}{{ .Prompt }}[/INST] `},
|
||||||
{{- range $index, $_ := .Messages }}
|
{"response", `[INST] {{ if .System }}{{ .System }}
|
||||||
{{- if eq .Role "user" }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}{{ "\n\n" }}
|
|
||||||
{{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
|
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||||||
{{- end }}
|
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||||||
|
|
||||||
|
{{ end }}
|
||||||
|
{{- range .Messages }}
|
||||||
|
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||||||
{{- end }}`},
|
{{- end }}`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
@@ -204,9 +256,9 @@ func TestExecuteWithMessages(t *testing.T) {
|
|||||||
{Role: "user", Content: "What is your name?"},
|
{Role: "user", Content: "What is your name?"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
|
`[INST] You are a helpful assistant!
|
||||||
|
|
||||||
What is your name?[/INST] `,
|
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"chatml",
|
"chatml",
|
||||||
@@ -220,12 +272,9 @@ What is your name?[/INST] `,
|
|||||||
{{ .Response }}<|im_end|>
|
{{ .Response }}<|im_end|>
|
||||||
`},
|
`},
|
||||||
{"messages", `
|
{"messages", `
|
||||||
{{- range $index, $_ := .Messages }}
|
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
|
||||||
{{- if and (eq .Role "user") (eq (len (slice $.Messages $index)) 1) $.System }}<|im_start|>system
|
{{ .Content }}<|im_end|>
|
||||||
{{ $.System }}<|im_end|>{{ "\n" }}
|
{{ end }}<|im_start|>assistant
|
||||||
{{- end }}<|im_start|>{{ .Role }}
|
|
||||||
{{ .Content }}<|im_end|>{{ "\n" }}
|
|
||||||
{{- end }}<|im_start|>assistant
|
|
||||||
`},
|
`},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
@@ -236,12 +285,12 @@ What is your name?[/INST] `,
|
|||||||
{Role: "user", Content: "What is your name?"},
|
{Role: "user", Content: "What is your name?"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
`<|im_start|>user
|
`<|im_start|>system
|
||||||
|
You are a helpful assistant!<|im_end|>
|
||||||
|
<|im_start|>user
|
||||||
Hello friend!<|im_end|>
|
Hello friend!<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Hello human!<|im_end|>
|
Hello human!<|im_end|>
|
||||||
<|im_start|>system
|
|
||||||
You are a helpful assistant!<|im_end|>
|
|
||||||
<|im_start|>user
|
<|im_start|>user
|
||||||
What is your name?<|im_end|>
|
What is your name?<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
@@ -258,9 +307,11 @@ What is your name?<|im_end|>
|
|||||||
`},
|
`},
|
||||||
{"messages", `
|
{"messages", `
|
||||||
{{- range .Messages }}
|
{{- range .Messages }}
|
||||||
{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
|
{{- if eq .Role "user" }}Question: {{ .Content }}
|
||||||
{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
|
|
||||||
{{- end }}
|
{{ else if eq .Role "assistant" }}Answer: {{ .Content }}
|
||||||
|
|
||||||
|
{{ end }}
|
||||||
{{- end }}Answer: `},
|
{{- end }}Answer: `},
|
||||||
},
|
},
|
||||||
Values{
|
Values{
|
||||||
@@ -300,11 +351,46 @@ Answer: `,
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.String() != tt.expected {
|
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
|
||||||
t.Errorf("expected\n%s,\ngot\n%s", tt.expected, b.String())
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithSuffix(t *testing.T) {
|
||||||
|
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
values Values
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, tt.values); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
You are a helpful assistant.### Instruction:
|
You are a helpful assistant.
|
||||||
|
|
||||||
|
### Instruction:
|
||||||
Hello, how are you?
|
Hello, how are you?
|
||||||
|
|
||||||
### Response:
|
### Response:
|
||||||
|
@@ -9,3 +9,4 @@ Source: system
|
|||||||
I'd like to show off how chat templating works! <step> Source: assistant
|
I'd like to show off how chat templating works! <step> Source: assistant
|
||||||
Destination: user
|
Destination: user
|
||||||
|
|
||||||
|
|
@@ -3,3 +3,4 @@ Source: user
|
|||||||
Hello, how are you? <step> Source: assistant
|
Hello, how are you? <step> Source: assistant
|
||||||
Destination: user
|
Destination: user
|
||||||
|
|
||||||
|
|
@@ -7,3 +7,4 @@ Source: user
|
|||||||
I'd like to show off how chat templating works! <step> Source: assistant
|
I'd like to show off how chat templating works! <step> Source: assistant
|
||||||
Destination: user
|
Destination: user
|
||||||
|
|
||||||
|
|
@@ -2,4 +2,6 @@
|
|||||||
You are a helpful assistant.
|
You are a helpful assistant.
|
||||||
<</SYS>>
|
<</SYS>>
|
||||||
|
|
||||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]
|
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
|
||||||
|
|
||||||
|
I'd like to show off how chat templating works! [/INST]
|
@@ -1,3 +1,5 @@
|
|||||||
[INST] <<SYS>><</SYS>>
|
[INST] <<SYS>><</SYS>>
|
||||||
|
|
||||||
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] I'd like to show off how chat templating works! [/INST]
|
Hello, how are you? [/INST] I'm doing great. How can I help you today?</s><s>[INST] <<SYS>><</SYS>>
|
||||||
|
|
||||||
|
I'd like to show off how chat templating works! [/INST]
|
@@ -1,2 +1,3 @@
|
|||||||
[INST] Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] You are a helpful assistant.
|
[INST] You are a helpful assistant.
|
||||||
I'd like to show off how chat templating works![/INST]
|
|
||||||
|
Hello, how are you?[/INST] I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works![/INST]
|
@@ -1 +1 @@
|
|||||||
GPT Correct System: You are a helpful assistant.<|end_of_turn|>GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:
|
GPT4 Correct System: You are a helpful assistant.<|end_of_turn|>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
|
2
template/testdata/openchat.gotmpl/user
vendored
2
template/testdata/openchat.gotmpl/user
vendored
@@ -1 +1 @@
|
|||||||
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant:
|
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant:
|
@@ -1 +1 @@
|
|||||||
GPT Correct User: Hello, how are you?<|end_of_turn|>GPT Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT Correct Assistant:
|
GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>GPT4 Correct Assistant:
|
@@ -1,14 +1,4 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}{{ .System }}
|
|
||||||
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}
|
|
||||||
{{- if eq .Role "user" }}USER: {{ .Content }}
|
|
||||||
{{ else if eq .Role "assistant" }}ASSISTANT: {{ .Content }}</s>
|
|
||||||
{{ end }}
|
|
||||||
{{- end }}ASSISTANT:
|
|
||||||
{{- else }}
|
|
||||||
{{ if .System }}{{ .System }}
|
{{ if .System }}{{ .System }}
|
||||||
|
|
||||||
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
|
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
|
||||||
{{ end }}ASSISTANT: {{ .Response }}
|
{{ end }}ASSISTANT: {{ .Response }}</s>
|
||||||
{{- end }}
|
|
||||||
|
@@ -1,15 +1,6 @@
|
|||||||
{{- if .Messages }}
|
|
||||||
{{- if .System }}<|system|>
|
|
||||||
{{ .System }}</s>
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}<|{{ .Role }}|>
|
|
||||||
{{ .Content }}</s>
|
|
||||||
{{ end }}<|assistant|>
|
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}<|system|>
|
{{ if .System }}<|system|>
|
||||||
{{ .System }}</s>
|
{{ .System }}</s>
|
||||||
{{ end }}{{ if .Prompt }}<|user|>
|
{{ end }}{{ if .Prompt }}<|user|>
|
||||||
{{ .Prompt }}</s>
|
{{ .Prompt }}</s>
|
||||||
{{ end }}<|assistant|>
|
{{ end }}<|assistant|>
|
||||||
{{ .Response }}</s>
|
{{ .Response }}</s>
|
||||||
{{- end }}
|
|
Reference in New Issue
Block a user