Compare commits: v0.1.49-rc ... roy-embed- (143 commits)
Commit SHA1s in this range (author, avatar, and date columns were empty in the export):
2647a0e443, ec4c35fe99, f5e3939220, ae27d9dcfd, 37096790a7, 997c903884, c8af3c2d96, 455e61170d,
4de1370a9d, bbf8f102ee, bb46bbcf5e, ac33aa7d37, a6cd8f6169, c78089263a, 3e5ea035d5, 5d604eec5b,
db0968f30c, 9b60a038e5, 83a0cb8d88, c0648233f2, d835368eb8, 5784c05397, f14aa5435d, f8fedbda20,
b3e5491e41, cc269ba094, a3c20e3f18, 80ee9b5e47, 5534f2cc6a, d321297d8a, 06e5d74e34, 5d707e6fd5,
283948c83b, 1475eab95f, 20090f3172, 69a2d4ccff, e8b954c646, c57317cbf0, 51b2fd299c, d0634b1596,
43606d6d6a, 70b1010fa5, 84e5721f3a, 319fb1ce03, b255445557, f02f83660c, b23424bb3c, 5fd6988126,
5b82960df8, cc9a252d8c, d281a6e603, 154f6f45d4, 0d41623b52, c279f96371, 499e87c9ba, cd0853f2d5,
d290e87513, 97c20ede33, 5a83f79afd, 987dbab0b0, a8388beb94, 5afbb60fc4, 4cb5d7decc, 8eac50dd4f,
4a565cbf94, 64039df6d7, 7ac6d462ec, ef5136a745, 8288ec8824, d02bbebb11, 224337b32f, 9e35d9bbee,
b9f5e16c80, e9f7f36029, 057d31861e, f7ee012300, 1ed0aa8fea, ef98803d63, 02fea420e5, 22c5451fc2,
ebc529cbb3, 23ebbaa46e, 9ac0a7a50b, e5c65a85df, 33627331a3, 36c87c433b, 179737feb7, 47353f5ee4,
10e768826c, 5056bb9c01, c4cf8ad559, 57ec6901eb, e64f9ebb44, 791650ddef, efbf41ed81, cf15589851,
19753c18c0, 41be28096a, 37a570f962, 5a739ff4cb, 4e262eb2a8, 4cfcbc328f, 79292ff3e0, 8ea500441d,
b50c818623, b99e750b62, 1f50356e8e, 22c81f62ec, 73e2c8f68f, f4408219e9, 2d1e3c3229, 4918fae535,
0aff67877e, f6f759fc5f, 9544a57ee4, b51e3b63ac, 6bbbc50f10, 9bbddc37a7, e4ff73297d, b44320db13,
0bacb30007, 53da2c6965, d8def1ff94, 571dc61955, 0e09c380fc, 0ee87615c7, f8241bfba3, 4607c70641,
c12f1c5b99, a08f20d910, 6cea036027, 5796bfc401, f1a379aa56, 9ae146993e, e0348d3fe8, 2cc854f8cb,
5304b765b2, fb6cbc02fb, 326363b3a7, ac7a842e55, 2c3fe1fd97, 269ed6e6a2, 784bf88b0d
.github/workflows/release.yaml (vendored): 12 changes
@@ -31,7 +31,7 @@ jobs:
security set-keychain-settings -lut 3600 build.keychain
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- name: Build Darwin
env:

@@ -87,7 +87,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- run: go get ./...
- run: |

@@ -141,13 +141,13 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- name: 'Install ROCm'
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

@@ -218,7 +218,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- name: 'Install CUDA'
run: |

@@ -306,7 +306,7 @@ jobs:
write-host "plugin installed"
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- run: go get
- uses: actions/download-artifact@v4
.github/workflows/test.yaml (vendored): 14 changes
@@ -63,7 +63,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- run: go get ./...
- run: |

@@ -126,7 +126,7 @@ jobs:
strategy:
matrix:
rocm-version:
- '6.1.1'
- '6.1.2'
runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:

@@ -163,13 +163,13 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- name: 'Install ROCm'
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"

@@ -200,7 +200,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- name: 'Install CUDA'
run: |

@@ -255,7 +255,7 @@ jobs:
submodules: recursive
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: false
- run: |
case ${{ matrix.arch }} in

@@ -297,7 +297,7 @@ jobs:
submodules: recursive
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
go-version: "stable"
cache: true
- run: |
case ${{ matrix.arch }} in
@@ -1,8 +1,8 @@
ARG GOLANG_VERSION=1.22.1
ARG GOLANG_VERSION=1.22.5
ARG CMAKE_VERSION=3.22.1
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=6.1.1
ARG ROCM_VERSION=6.1.2

# Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code
@@ -64,7 +64,8 @@ Here are some example models that can be downloaded:
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Solar | 10.7B | 6.1GB | `ollama run solar` |

> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

## Customize a model

@@ -293,6 +294,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)

### Terminal
@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
return nil
}

// Embeddings generates embeddings from a model.
// Embed generates embeddings from a model.
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
var resp EmbedResponse
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

// Embeddings generates an embedding from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
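For reference, a minimal sketch of calling the new `Embed` client method added above. This is not part of the diff: it assumes a locally running Ollama server reachable via `OLLAMA_HOST` and that an embedding model (the `all-minilm` name and inputs here are illustrative) has already been pulled.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// ClientFromEnvironment reads OLLAMA_HOST (default http://127.0.0.1:11434).
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Input accepts either a single string or a list of strings.
	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm", // illustrative model name
		Input: []string{"Why is the sky blue?", "Why is the grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}

	// One embedding vector is returned per input string.
	fmt.Println("embeddings returned:", len(resp.Embeddings))
}
```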
api/types.go: 102 changes
@@ -47,6 +47,9 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"`

// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`

// System overrides the model's default system message/prompt.
System string `json:"system"`

@@ -97,17 +100,80 @@ type ChatRequest struct {
// followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"`

// Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"`

// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}

type Tools []Tool

func (t Tools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}

// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}

*m = Message(a)
m.Role = strings.ToLower(m.Role)
return nil
}

type ToolCall struct {
Function ToolCallFunction `json:"function"`
}

type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}

type ToolCallFunctionArguments map[string]any

func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}

type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}

type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}

func (t *ToolFunction) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
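To make the new tool-calling types concrete, here is a hedged sketch of a chat request that advertises one tool and prints any tool calls the model returns. It is not part of the diff: the model name is illustrative (tool calling needs a model trained for it, e.g. a Mistral v0.3 build), and the tool schema mirrors the weather example from docs/api.md elsewhere in this comparison.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Unmarshalling a JSON tool definition into api.Tools avoids spelling out
	// the nested anonymous Parameters struct by hand.
	toolJSON := `[{"type":"function","function":{"name":"get_current_weather",
	  "description":"Get the current weather for a location",
	  "parameters":{"type":"object","required":["location"],
	    "properties":{"location":{"type":"string","description":"City and country"}}}}}]`
	var tools api.Tools
	if err := json.Unmarshal([]byte(toolJSON), &tools); err != nil {
		log.Fatal(err)
	}

	stream := false // tool calls currently require stream=false per docs/api.md
	req := &api.ChatRequest{
		Model:    "mistral", // illustrative; use a tool-capable model
		Messages: []api.Message{{Role: "user", Content: "What is the weather today in Paris?"}},
		Tools:    tools,
		Stream:   &stream,
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		// Each tool call carries the function name and its JSON arguments.
		for _, call := range resp.Message.ToolCalls {
			fmt.Println("model wants to call:", call.Function.Name, call.Function.Arguments.String())
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```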
// ChatResponse is the response returned by [Client.Chat]. Its fields are

@@ -173,6 +239,30 @@ type Runner struct {
NumThread int `json:"num_thread,omitempty"`
}

// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
// Model is the model name.
Model string `json:"model"`

// Input is the input to embed.
Input any `json:"input"`

// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`

Truncate *bool `json:"truncate,omitempty"`

// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}

// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
}

// EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct {
// Model is the model name.

@@ -219,8 +309,10 @@ type DeleteRequest struct {

// ShowRequest is the request passed to [Client.Show].
type ShowRequest struct {
Model string `json:"model"`
System string `json:"system"`
Model string `json:"model"`
System string `json:"system"`

// Template is deprecated
Template string `json:"template"`
Verbose bool `json:"verbose"`
@@ -208,3 +208,26 @@ func TestUseMmapFormatParams(t *testing.T) {
})
}
}

func TestMessage_UnmarshalJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"role": "USER", "content": "Hello!"}`, "user"},
{`{"role": "System", "content": "Initialization complete."}`, "system"},
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
}

for _, test := range tests {
var msg Message
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
t.Errorf("Unexpected error: %v", err)
}

if msg.Role != test.expected {
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
}
}
}
@@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved

[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"

[Messages]
WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models.
@@ -843,7 +843,6 @@ type runOptions struct {
WordWrap bool
Format string
System string
Template string
Images []api.ImageData
Options map[string]interface{}
MultiModal bool

@@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
}

@@ -1346,7 +1344,6 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
})
default:
appendEnvDocs(cmd, envs)
@@ -27,7 +27,6 @@ const (
MultilineNone MultilineState = iota
MultilinePrompt
MultilineSystem
MultilineTemplate
)

func loadModel(cmd *cobra.Command, opts *runOptions) error {

@@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, "  /set parameter ...     Set a parameter")
fmt.Fprintln(os.Stderr, "  /set system <string>   Set system message")
fmt.Fprintln(os.Stderr, "  /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, "  /set history           Enable history")
fmt.Fprintln(os.Stderr, "  /set nohistory         Disable history")
fmt.Fprintln(os.Stderr, "  /set wordwrap          Enable wordwrap")

@@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}

multiline = MultilineNone

@@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system", "template":
case "system":
if len(args) < 3 {
usageSet()
continue
}

if args[1] == "system" {
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
multiline = MultilineSystem

line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`)

@@ -356,23 +346,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue
}

if args[1] == "system" {
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
opts.System = sb.String() // for display in modelfile
newMessage := api.Message{Role: "system", Content: sb.String()}
// Check if the slice is not empty and the last message is from 'system'
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
// Replace the last message
opts.Messages[len(opts.Messages)-1] = newMessage
} else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()

sb.Reset()
continue

@@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
req := &api.ShowRequest{
Name: opts.Model,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
resp, err := client.Show(cmd.Context(), req)

@@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.")
}
case "template":
switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
if resp.Template != "" {
fmt.Println(resp.Template)
default:
} else {
fmt.Println("No prompt template was specified for this model.")
}
default:

@@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}

if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}

keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)
@@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},

@@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) {
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]

@@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
@@ -71,6 +71,11 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
"tokenizer.ggml.unknown_token_id": uint32(0),
}

if m.Params.HeadDimension > 0 {
kv["llama.attention.key_length"] = uint32(m.Params.HeadDimension)
kv["llama.attention.value_length"] = uint32(m.Params.HeadDimension)
}

return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}
docs/api.md: 201 changes
@@ -40,6 +40,7 @@ Generate a response for a given prompt with a provided model. This is a streamin

- `model`: (required) the [model name](#model-names)
- `prompt`: the prompt to generate a response for
- `suffix`: the text after the model response
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)

Advanced parameters (optional):

@@ -57,7 +58,8 @@ Advanced parameters (optional):

Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#request-json-mode) below.

> Note: it's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace.
> [!IMPORTANT]
> It's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace.

### Examples

@@ -148,8 +150,44 @@ If `stream` is set to `false`, the response will be a single JSON object:
}
```

#### Request (with suffix)

##### Request

```shell
curl http://localhost:11434/api/generate -d '{
"model": "codellama:code",
"prompt": "def compute_gcd(a, b):",
"suffix": "    return result",
"options": {
"temperature": 0
},
"stream": false
}'
```

##### Response

```json
{
"model": "codellama:code",
"created_at": "2024-07-22T20:47:51.147561Z",
"response": "\n  if a == 0:\n   return b\n  else:\n   return compute_gcd(b % a, a)\n\ndef compute_lcm(a, b):\n  result = (a * b) / compute_gcd(a, b)\n",
"done": true,
"done_reason": "stop",
"context": [...],
"total_duration": 1162761250,
"load_duration": 6683708,
"prompt_eval_count": 17,
"prompt_eval_duration": 201222000,
"eval_count": 63,
"eval_duration": 953997000
}
```

#### Request (JSON mode)

> [!IMPORTANT]
> When `format` is set to `json`, the output will always be a well-formed JSON object. It's important to also instruct the model to respond in JSON.

##### Request

@@ -380,12 +418,14 @@ Generate the next message in a chat with a provided model. This is a streaming e

- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`

The `message` object has the following fields:

- `role`: the role of the message, either `system`, `user` or `assistant`
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
- `content`: the content of the message
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools the model wants to use

Advanced parameters (optional):

@@ -622,6 +662,79 @@ curl http://localhost:11434/api/chat -d '{
}
```

#### Chat request (with tools)

##### Request

```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "What is the weather today in Paris?"
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather for, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location", "format"]
}
}
}
]
}'
```

##### Response

```json
{
"model": "mistral:7b-instruct-v0.3-q4_K_M",
"created_at": "2024-07-22T20:33:28.123648Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_current_weather",
"arguments": {
"format": "celsius",
"location": "Paris, FR"
}
}
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 885095291,
"load_duration": 3753500,
"prompt_eval_count": 122,
"prompt_eval_duration": 328493000,
"eval_count": 33,
"eval_duration": 552222000
}
```

## Create a Model

```shell

@@ -1026,7 +1139,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
## Generate Embeddings

```shell
POST /api/embeddings
POST /api/embed
```

Generate embeddings from a model

@@ -1034,10 +1147,11 @@ Generate embeddings from a model
### Parameters

- `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for
- `input`: text or list of text to generate embeddings for

Advanced parameters:

- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

@@ -1046,9 +1160,9 @@ Advanced parameters:
#### Request

```shell
curl http://localhost:11434/api/embeddings -d '{
curl http://localhost:11434/api/embed -d '{
"model": "all-minilm",
"prompt": "Here is an article about llamas..."
"input": "Why is the sky blue?"
}'
```

@@ -1056,10 +1170,35 @@ curl http://localhost:11434/api/embeddings -d '{

```json
{
"embedding": [
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]
"model": "all-minilm",
"embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
]]
}
```

#### Request (Multiple input)

```shell
curl http://localhost:11434/api/embed -d '{
"model": "all-minilm",
"input": ["Why is the sky blue?", "Why is the grass green?"]
}'
```

#### Response

```json
{
"model": "all-minilm",
"embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
],[
-0.0098027075, 0.06042469, 0.025257962, -0.006364387, 0.07272725,
0.017194884, 0.09032035, -0.051705178, 0.09951512, 0.09072481
]]
}
```

@@ -1106,3 +1245,45 @@ A single JSON object will be returned.
]
}
```

## Generate Embedding

> Note: this endpoint has been superseded by `/api/embed`

```shell
POST /api/embeddings
```

Generate embeddings from a model

### Parameters

- `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for

Advanced parameters:

- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

### Examples

#### Request

```shell
curl http://localhost:11434/api/embeddings -d '{
"model": "all-minilm",
"prompt": "Here is an article about llamas..."
}'
```

#### Response

```json
{
"embedding": [
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]
}
```
@@ -272,4 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512

Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
docs/gpu.md: 15 changes
@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`

## AMD Radeon
Ollama supports the following AMD GPUs:

### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |

### Overrides
### Windows Support
With ROCm v6.1, the following GPUs are supported on Windows.

| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |

### Overrides on Linux
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
some cases you can force the system to try to use a similar LLVM target that is
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)

@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
server. If you have an unsupported AMD GPU you can experiment using the list of
supported types below.

At this time, the known supported GPU types are the following LLVM Targets.
At this time, the known supported GPU types on linux are the following LLVM Targets.
This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|
@@ -1,6 +1,7 @@
# Ollama Model File

> Note: `Modelfile` syntax is in development
> [!NOTE]
> `Modelfile` syntax is in development

A model file is the blueprint to create and share models with Ollama.
@@ -78,8 +78,8 @@ curl http://localhost:11434/v1/chat/completions \
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [x] Tools (streaming support coming soon)
- [ ] Vision
- [ ] Function calling
- [ ] Logprobs

#### Supported request fields

@@ -97,16 +97,12 @@ curl http://localhost:11434/v1/chat/completions \
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [ ] `logit_bias`
- [ ] `tools`
- [x] `tools`
- [ ] `tool_choice`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`

#### Notes

- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached

## Models

Before using a model, pull it locally `ollama pull`:
docs/template.md (new file): 173 lines
@@ -0,0 +1,173 @@
# Template

Ollama provides a powerful templating engine backed by Go's built-in templating engine to construct prompts for your large language model. This feature is a valuable tool to get the most out of your models.

## Basic Template Structure

A basic Go template consists of three main parts:

* **Layout**: The overall structure of the template.
* **Variables**: Placeholders for dynamic data that will be replaced with actual values when the template is rendered.
* **Functions**: Custom functions or logic that can be used to manipulate the template's content.

Here's an example of a simple chat template:

```gotmpl
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
```

In this example, we have:

* A basic messages structure (layout)
* Three variables: `Messages`, `Role`, and `Content` (variables)
* A custom function (action) that iterates over an array of items (`range .Messages`) and displays each item
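Because these are plain Go text/template documents, the chat template above can be exercised outside Ollama with nothing but the standard library. The sketch below is not part of the new file: it uses a simplified `Message` struct as a stand-in for the data Ollama passes to templates.

```go
package main

import (
	"os"
	"text/template"
)

// Message is a simplified stand-in for the structure Ollama exposes to templates.
type Message struct {
	Role    string
	Content string
}

func main() {
	const tmpl = `{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
`
	t := template.Must(template.New("chat").Parse(tmpl))

	data := struct{ Messages []Message }{
		Messages: []Message{
			{Role: "system", Content: "You are a helpful assistant."},
			{Role: "user", Content: "Hello!"},
		},
	}

	// Renders each message on its own line, e.g. "user: Hello!".
	if err := t.Execute(os.Stdout, data); err != nil {
		panic(err)
	}
}
```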
## Adding templates to your model

By default, models imported into Ollama have a default template of `{{ .Prompt }}`, i.e. user inputs are sent verbatim to the LLM. This is appropriate for text or code completion models but lacks essential markers for chat or instruction models.

Omitting a template in these models puts the responsibility of correctly templating input onto the user. Adding a template allows users to easily get the best results from the model.

To add templates in your model, you'll need to add a `TEMPLATE` command to the Modelfile. Here's an example using Meta's Llama 3.

```dockerfile
FROM llama3

TEMPLATE """{{- if .System }}<|start_header_id|>system<|end_header_id|>

{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>

{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

"""
```

## Variables

`System` (string): system prompt

`Prompt` (string): user prompt

`Response` (string): assistant response

`Suffix` (string): text inserted after the assistant's response

`Messages` (list): list of messages

`Messages[].Role` (string): role which can be one of `system`, `user`, `assistant`, or `tool`

`Messages[].Content` (string): message content

`Messages[].ToolCalls` (list): list of tools the model wants to call

`Messages[].ToolCalls[].Function` (object): function to call

`Messages[].ToolCalls[].Function.Name` (string): function name

`Messages[].ToolCalls[].Function.Arguments` (map): mapping of argument name to argument value

`Tools` (list): list of tools the model can access

`Tools[].Type` (string): schema type. `type` is always `function`

`Tools[].Function` (object): function definition

`Tools[].Function.Name` (string): function name

`Tools[].Function.Description` (string): function description

`Tools[].Function.Parameters` (object): function parameters

`Tools[].Function.Parameters.Type` (string): schema type. `type` is always `object`

`Tools[].Function.Parameters.Required` (list): list of required properties

`Tools[].Function.Parameters.Properties` (map): mapping of property name to property definition

`Tools[].Function.Parameters.Properties[].Type` (string): property type

`Tools[].Function.Parameters.Properties[].Description` (string): property description

`Tools[].Function.Parameters.Properties[].Enum` (list): list of valid values

## Tips and Best Practices

Keep the following tips and best practices in mind when working with Go templates:

* **Be mindful of dot**: Control flow structures like `range` and `with` change the value of `.`
* **Out-of-scope variables**: Use `$.` to reference variables not currently in scope, starting from the root
* **Whitespace control**: Use `-` to trim leading (`{{-`) and trailing (`-}}`) whitespace

## Examples

### Example Messages

#### ChatML

ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.

```gotmpl
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
```

### Example Tools

Tools support can be added to a model by adding a `{{ .Tools }}` node to the template. This feature is useful for models trained to call external tools and can be a powerful way to retrieve real-time data or perform complex tasks.

#### Mistral

Mistral v0.3 and Mixtral 8x22B support tool calling.

```gotmpl
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}

{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}
```

### Example Fill-in-Middle

Fill-in-middle support can be added to a model by adding a `{{ .Suffix }}` node to the template. This feature is useful for models that are trained to generate text in the middle of user input, such as code completion models.

#### CodeLlama

CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://ollama.com/library/codellama:13b-code) code completion models support fill-in-middle.

```gotmpl
<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
```

> [!NOTE]
> CodeLlama 34B and 70B code completion and all instruct and Python fine-tuned models do not support fill-in-middle.

#### Codestral

Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.

```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```
@@ -43,8 +43,6 @@ var (
MaxRunners int
// Set via OLLAMA_MAX_QUEUE in the environment
MaxQueuedRequests int
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
// Set via OLLAMA_MODELS in the environment
ModelsDir string
// Set via OLLAMA_NOHISTORY in the environment

@@ -89,7 +87,6 @@ func AsMap() map[string]EnvVar {
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},

@@ -194,16 +191,6 @@ func LoadConfig() {

TmpDir = clean("OLLAMA_TMPDIR")

userLimit := clean("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseUint(userLimit, 10, 64)
if err != nil {
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
} else {
MaxVRAM = avail
}
}

LLMLibrary = clean("OLLAMA_LLM_LIBRARY")

if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
go.mod: 3 changes
@@ -18,6 +18,7 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

@@ -71,7 +72,7 @@ require (
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
golang.org/x/text v0.15.0 // indirect
golang.org/x/text v0.15.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)
@@ -49,9 +49,17 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
}

func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version.  Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Favor our bundled version

// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}

// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")

@@ -87,14 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
}
}

// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}
@@ -33,9 +33,10 @@ type HipLib struct {
}

func NewHipLib() (*HipLib, error) {
h, err := windows.LoadLibrary("amdhip64.dll")
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
h, err := windows.LoadLibrary("amdhip64_6.dll")
if err != nil {
return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
}
hl := &HipLib{}
hl.dll = h

@@ -84,9 +85,8 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
}

slog.Debug("hipDriverGetVersion", "version", version)
// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
driverMajor = version / 10000000
driverMinor = (version - (driverMajor * 10000000)) / 100000

return driverMajor, driverMinor, nil
}
@@ -22,8 +22,8 @@ const (

var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
)

func AMDGetGPUInfo() []RocmGPUInfo {

@@ -35,12 +35,11 @@ func AMDGetGPUInfo() []RocmGPUInfo {
}
defer hl.Release()

// TODO - this reports incorrect version information, so omitting for now
// driverMajor, driverMinor, err := hl.AMDDriverVersion()
// if err != nil {
// 	// For now this is benign, but we may eventually need to fail compatibility checks
// 	slog.Debug("error looking up amd driver version", "error", err)
// }
driverMajor, driverMinor, err := hl.AMDDriverVersion()
if err != nil {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug("error looking up amd driver version", "error", err)
}

// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()

@@ -93,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
continue
}
if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) {
// Strip off Target Features when comparing
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")

@@ -132,10 +132,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory,
Name: name,
Compute: gfx,

// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
DriverMajor: driverMajor,
DriverMinor: driverMinor,
},
index: i,
}
gpu/gpu.go: 30 changes
@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor

// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}

// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo)
}

@@ -338,14 +360,17 @@ func GetGPUInfo() GpuInfoList {
"before",
"total", format.HumanBytes2(cpus[0].TotalMemory),
"free", format.HumanBytes2(cpus[0].FreeMemory),
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
),
slog.Group(
"now",
"total", format.HumanBytes2(mem.TotalMemory),
"free", format.HumanBytes2(mem.FreeMemory),
"free_swap", format.HumanBytes2(mem.FreeSwap),
),
)
cpus[0].FreeMemory = mem.FreeMemory
cpus[0].FreeSwap = mem.FreeSwap
}

var memInfo C.mem_info_t

@@ -374,9 +399,14 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory")
continue
}
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data",
"gpu", gpu.ID,
"name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group(
"before",
"total", format.HumanBytes2(gpu.TotalMemory),
@@ -56,7 +56,8 @@ func GetCPUInfo() GpuInfoList {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0,
FreeMemory: uint64(C.getFreeMemory()),
// FreeSwap omitted as Darwin uses dynamic paging
}, nil
}
@@ -2,3 +2,4 @@
#include <stdint.h>
uint64_t getRecommendedMaxVRAM();
uint64_t getPhysicalMemory();
uint64_t getFreeMemory();
@@ -1,4 +1,5 @@
// go:build darwin
#import <Foundation/Foundation.h>
#import <mach/mach.h>
#include "gpu_info_darwin.h"

uint64_t getRecommendedMaxVRAM() {

@@ -8,6 +9,27 @@ uint64_t getRecommendedMaxVRAM() {
return result;
}

// getPhysicalMemory returns the total physical memory in bytes
uint64_t getPhysicalMemory() {
return [[NSProcessInfo processInfo] physicalMemory];
return [NSProcessInfo processInfo].physicalMemory;
}

// getFreeMemory returns the total free memory in bytes, including inactive
// memory that can be reclaimed by the system.
uint64_t getFreeMemory() {
mach_port_t host_port = mach_host_self();
mach_msg_type_number_t host_size = sizeof(vm_statistics64_data_t) / sizeof(integer_t);
vm_size_t pagesize;
vm_statistics64_data_t vm_stat;

host_page_size(host_port, &pagesize);
if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stat, &host_size) != KERN_SUCCESS) {
return 0;
}

uint64_t free_memory = (uint64_t)vm_stat.free_count * pagesize;
free_memory += (uint64_t)vm_stat.speculative_count * pagesize;
free_memory += (uint64_t)vm_stat.inactive_count * pagesize;

return free_memory;
}
@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
var mem memInfo
|
||||
var total, available, free, buffers, cached uint64
|
||||
var total, available, free, buffers, cached, freeSwap uint64
|
||||
f, err := os.Open("/proc/meminfo")
|
||||
if err != nil {
|
||||
return mem, err
|
||||
@@ -70,20 +70,21 @@ func GetCPUMem() (memInfo, error) {
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
case strings.HasPrefix(line, "Cached:"):
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
case strings.HasPrefix(line, "SwapFree:"):
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
default:
continue
}
if err != nil {
return mem, err
}

if total > 0 && available > 0 {
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = available * format.KibiByte
return mem, nil
}
}
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
mem.FreeSwap = freeSwap * format.KibiByte
if available > 0 {
mem.FreeMemory = available * format.KibiByte
} else {
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
}
return mem, nil
}
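Note: for readers following the new fallback logic, here is a minimal, self-contained Go sketch of the same rule: prefer MemAvailable, otherwise use MemFree+Buffers+Cached, and report SwapFree separately. The function name is illustrative; the real implementation is the GetCPUMem shown above.

```go
package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

// cpuMemKiB reports total, free and free-swap memory in KiB, using the same
// MemAvailable-first strategy as the patched GetCPUMem.
func cpuMemKiB() (total, free, freeSwap uint64, err error) {
	f, err := os.Open("/proc/meminfo")
	if err != nil {
		return 0, 0, 0, err
	}
	defer f.Close()

	var available, plainFree, buffers, cached uint64
	s := bufio.NewScanner(f)
	for s.Scan() {
		line := s.Text()
		switch {
		case strings.HasPrefix(line, "MemTotal:"):
			fmt.Sscanf(line, "MemTotal:%d", &total)
		case strings.HasPrefix(line, "MemAvailable:"):
			fmt.Sscanf(line, "MemAvailable:%d", &available)
		case strings.HasPrefix(line, "MemFree:"):
			fmt.Sscanf(line, "MemFree:%d", &plainFree)
		case strings.HasPrefix(line, "Buffers:"):
			fmt.Sscanf(line, "Buffers:%d", &buffers)
		case strings.HasPrefix(line, "Cached:"):
			fmt.Sscanf(line, "Cached:%d", &cached)
		case strings.HasPrefix(line, "SwapFree:"):
			fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
		}
	}
	if available > 0 {
		free = available
	} else {
		free = plainFree + buffers + cached
	}
	return total, free, freeSwap, s.Err()
}

func main() {
	fmt.Println(cpuMemKiB())
}
```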
@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
if r1 == 0 {
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
}
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
}
@@ -10,6 +10,7 @@ import (
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
FreeSwap uint64 `json:"free_swap,omitempty"`
}

// Beginning of an `ollama info` command
@@ -52,7 +53,8 @@ type CPUInfo struct {

type CudaGPUInfo struct {
GpuInfo
index int //nolint:unused,nolintlint
OSOverhead uint64 // Memory overhead between the driver library and management library
index int //nolint:unused,nolintlint
}
type CudaGPUInfoList []CudaGPUInfo
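Note: the new free_swap field carries through to the JSON form of memInfo. A small hedged example of what the struct above serializes to (values invented; zero fields drop out via omitempty):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the memInfo struct above, populated with invented values.
type memInfo struct {
	TotalMemory uint64 `json:"total_memory,omitempty"`
	FreeMemory  uint64 `json:"free_memory,omitempty"`
	FreeSwap    uint64 `json:"free_swap,omitempty"`
}

func main() {
	b, _ := json.Marshal(memInfo{TotalMemory: 16 << 30, FreeMemory: 8 << 30})
	// free_swap is omitted when zero:
	// {"total_memory":17179869184,"free_memory":8589934592}
	fmt.Println(string(b))
}
```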
@@ -69,7 +69,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
reqLimit := len(req)
iterLimit := 5

vram := os.Getenv("OLLAMA_MAX_VRAM")
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
if vram != "" {
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
@@ -106,7 +106,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {

// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM")
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
@@ -12,7 +12,7 @@ import (

func TestContextExhaustion(t *testing.T) {
// Longer needed for small footprint GPUs
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
@@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
}
201
integration/embed_test.go
Normal file
@@ -0,0 +1,201 @@
//go:build integration

package integration

import (
"context"
"math"
"testing"
"time"

"github.com/ollama/ollama/api"
)

func floatsEqual32(a, b float32) bool {
return math.Abs(float64(a-b)) <= 1e-4
}

func floatsEqual64(a, b float64) bool {
return math.Abs(a-b) <= 1e-4
}

func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
}

res, err := embeddingTestHelper(ctx, t, req)

if err != nil {
t.Fatalf("error: %v", err)
}

if len(res.Embedding) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
}

if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
}
}

func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}

res, err := embedTestHelper(ctx, t, req)

if err != nil {
t.Fatalf("error: %v", err)
}

if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}

if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}

if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
}
}

func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}

res, err := embedTestHelper(ctx, t, req)

if err != nil {
t.Fatalf("error: %v", err)
}

if len(res.Embeddings) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
}

if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}

if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
}
}

func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

truncTrue, truncFalse := true, false

type testReq struct {
Name string
Request api.EmbedRequest
}

reqs := []testReq{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
},
}

res := make(map[string]*api.EmbedResponse)

for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}

if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}

if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}

// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})

if err == nil {
t.Fatal("expected error, got nil")
}
}

func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}

response, err := client.Embeddings(ctx, &req)

if err != nil {
return nil, err
}

return response, nil
}

func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}

response, err := client.Embed(ctx, &req)

if err != nil {
return nil, err
}

return response, nil
}
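Note: outside the test harness, the new batched endpoint is reached through the same api client the helpers above use. A hedged usage sketch (it assumes a running local server and that all-minilm has already been pulled):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// One request, two inputs, one embedding per input.
	res, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(res.Embeddings), len(res.Embeddings[0])) // 2 vectors, 384 floats each for all-minilm
}
```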
28
llm/ext_server/CMakeLists.txt
vendored
@@ -1,17 +1,13 @@

set(TARGET ollama_llama_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE ggml llama common llava ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ollama_llama_server ggml llama
RUNTIME DESTINATION "${CMAKE_BINARY_DIR}/bin"
LIBRARY DESTINATION "${CMAKE_BINARY_DIR}/bin"
COMPONENT ollama_llama_server)
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
set(TARGET ollama_llama_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE ggml llama common llava ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_11)
68
llm/ext_server/server.cpp
vendored
@@ -1413,7 +1413,7 @@ struct llama_server_context
|
||||
return get_slot(-1);
|
||||
}
|
||||
|
||||
LOG_INFO("slot with common prefix found", {{
|
||||
LOG_DEBUG("slot with common prefix found", {{
|
||||
"slot_id", slot->id,
|
||||
"characters", longest
|
||||
}});
|
||||
@@ -1688,22 +1688,8 @@ struct llama_server_context
|
||||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
char buf[256];
|
||||
llama_model_meta_val_str(model, "general.architecture", buf, 256);
|
||||
bool gemma2 = strcmp(buf, "gemma2") == 0;
|
||||
|
||||
int32_t truncate_at = slot.n_ctx;
|
||||
|
||||
// truncate at 2/3 of the context length for gemma2 models
|
||||
// as they do not support context shifts (from the sliding window implementation).
|
||||
// this way, prompts that almost fit the context length can still generate a full
|
||||
// response without a sudden stop from hitting the context limit
|
||||
if (gemma2) {
|
||||
truncate_at = 2 * slot.n_ctx / 3;
|
||||
}
|
||||
|
||||
// if input prompt is too big, truncate it, if group attention self-extend is disabled
|
||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
|
||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
|
||||
{
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
const int n_shift = n_left / 2;
|
||||
@@ -1731,19 +1717,6 @@ struct llama_server_context
|
||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
// Models with sliding window attention do not work with context shifts, so
|
||||
// limit their prediction to the context length
|
||||
if (gemma2) {
|
||||
int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
|
||||
slot.n_predict = limit;
|
||||
slot.params.n_predict = limit;
|
||||
LOG_INFO("model does not support sliding window, limiting generation", {
|
||||
{"n_ctx", slot.n_ctx},
|
||||
{"n_prompt_tokens", slot.n_prompt_tokens},
|
||||
{"n_predict", slot.n_predict}
|
||||
});
|
||||
}
|
||||
|
||||
if (!slot.params.cache_prompt)
|
||||
{
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
@@ -3215,26 +3188,33 @@ int main(int argc, char **argv) {
prompt = "";
}

json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
if (prompt.size() == 1) {
prompt = prompt[0];
}

// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);

// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// get the result
task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}

// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
responses = result.result_json.value("results", std::vector<json>{result.result_json});
json embeddings = json::array();
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
});

// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
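Note: with this change the runner's embedding handler always returns an array of vectors under "embedding", even for a single prompt. A hedged Go sketch of decoding that payload on the client side (the struct name and sample values are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// runnerEmbedResponse mirrors the JSON shape produced above:
// {"embedding": [[...], [...]]}.
type runnerEmbedResponse struct {
	Embedding [][]float32 `json:"embedding"`
}

func main() {
	payload := []byte(`{"embedding": [[0.01, 0.02], [-0.03, 0.04]]}`)
	var res runnerEmbedResponse
	if err := json.Unmarshal(payload, &res); err != nil {
		panic(err)
	}
	fmt.Println(len(res.Embedding), len(res.Embedding[0])) // 2 2
}
```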
@@ -81,7 +81,6 @@ apply_patches() {
|
||||
build() {
|
||||
cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
|
||||
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
|
||||
cmake --install ${BUILD_DIR} --component ollama_llama_server
|
||||
}
|
||||
|
||||
compress() {
|
||||
|
@@ -18,7 +18,7 @@ sign() {
|
||||
fi
|
||||
}
|
||||
|
||||
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DGGML_METAL_EMBED_LIBRARY=on -DGGML_OPENMP=off"
|
||||
COMMON_DARWIN_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DGGML_METAL_EMBED_LIBRARY=on -DGGML_OPENMP=off"
|
||||
|
||||
case "${GOARCH}" in
|
||||
"amd64")
|
||||
@@ -27,7 +27,7 @@ case "${GOARCH}" in
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DGGML_BLAS=off -DGGML_ACCELERATE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_BLAS=off -DGGML_ACCELERATE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/darwin/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
@@ -75,7 +75,7 @@ case "${GOARCH}" in
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/darwin/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
|
@@ -51,7 +51,7 @@ if [ -z "${CUDACXX}" ]; then
|
||||
export CUDACXX=$(command -v nvcc)
|
||||
fi
|
||||
fi
|
||||
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off"
|
||||
COMMON_CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off"
|
||||
source $(dirname $0)/gen_common.sh
|
||||
init_vars
|
||||
git_module_setup
|
||||
@@ -77,7 +77,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
|
||||
init_vars
|
||||
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
|
||||
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cpu"
|
||||
echo "Building custom CPU"
|
||||
build
|
||||
@@ -93,7 +93,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
# -DGGML_AVX512_VBMI -- 2018 Intel Cannon Lake
|
||||
# -DGGML_AVX512_VNNI -- 2021 Intel Alder Lake
|
||||
|
||||
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
|
||||
COMMON_CPU_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
|
||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
|
||||
echo "Building custom CUDA GPU"
|
||||
else
|
||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
|
||||
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
|
||||
fi
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
|
||||
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
|
||||
fi
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
|
||||
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
|
||||
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
|
||||
|
@@ -6,18 +6,9 @@ function amdGPUs {
|
||||
if ($env:AMDGPU_TARGETS) {
|
||||
return $env:AMDGPU_TARGETS
|
||||
}
|
||||
# TODO - load from some common data file for linux + windows build consistency
|
||||
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||
# https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
|
||||
$GPU_LIST = @(
|
||||
"gfx900"
|
||||
"gfx906:xnack-"
|
||||
"gfx908:xnack-"
|
||||
"gfx90a:xnack+"
|
||||
"gfx90a:xnack-"
|
||||
"gfx940"
|
||||
"gfx941"
|
||||
"gfx942"
|
||||
"gfx1010"
|
||||
"gfx1012"
|
||||
"gfx1030"
|
||||
"gfx1100"
|
||||
"gfx1101"
|
||||
@@ -366,6 +357,7 @@ function build_rocm() {
|
||||
"-DCMAKE_C_COMPILER=clang.exe",
|
||||
"-DCMAKE_CXX_COMPILER=clang++.exe",
|
||||
"-DGGML_HIPBLAS=on",
|
||||
"-DLLAMA_CUDA_NO_PEER_COPY=on",
|
||||
"-DHIP_PLATFORM=amd",
|
||||
"-DGGML_AVX=on",
|
||||
"-DGGML_AVX2=off",
|
||||
@@ -394,7 +386,6 @@ function build_rocm() {
|
||||
sign
|
||||
install
|
||||
|
||||
# Assumes v5.7, may need adjustments for v6
|
||||
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||
|
26
llm/ggml.go
26
llm/ggml.go
@@ -424,6 +424,32 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
||||
)
|
||||
case "chatglm":
|
||||
fullOffload = 4 * batch * (embedding + vocab)
|
||||
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
|
||||
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
|
||||
fullOffload = max(
|
||||
fullOffload,
|
||||
4*batch*(2+
|
||||
2*embedding+
|
||||
context+
|
||||
context*heads+
|
||||
embeddingHeadsK*heads+
|
||||
qkvBias.Shape[0]),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
partialOffload,
|
||||
4*batch*(1+
|
||||
2*embedding+
|
||||
embeddingHeadsK*heads+
|
||||
context+
|
||||
context*heads)+
|
||||
4*embeddingHeadsK*context+
|
||||
4*context*embeddingHeadsK+
|
||||
4*qkvBias.Shape[0],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
|
@@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{
|
||||
"tokenizer.ggml.add_bos_token",
|
||||
"tokenizer.ggml.add_eos_token",
|
||||
"tokenizer.chat_template",
|
||||
"bert.pooling_type",
|
||||
},
|
||||
}
|
||||
|
||||
|
Submodule llm/llama.cpp updated: d7fd29fff1...d94c6e0ccb
17
llm/llm.go
17
llm/llm.go
@@ -1,12 +1,13 @@
|
||||
package llm
|
||||
|
||||
// #cgo CFLAGS: -Illama.cpp/include -Illama.cpp/ggml/include
|
||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/src/libllama.a ${SRCDIR}/build/darwin/arm64_static/ggml/src/libggml.a -lstdc++ -framework Accelerate -framework Metal
|
||||
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/src/libllama.a ${SRCDIR}/build/darwin/x86_64_static/ggml/src/libggml.a -lstdc++ -framework Accelerate -framework Metal
|
||||
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/src/libllama.a ${SRCDIR}/build/windows/amd64_static/ggml/src/libggml.a -static -lstdc++
|
||||
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/src/libllama.a ${SRCDIR}/build/windows/arm64_static/ggml/src/libggml.a -static -lstdc++
|
||||
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/src/libllama.a ${SRCDIR}/build/linux/x86_64_static/ggml/src/libggml.a -lstdc++
|
||||
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/src/libllama.a ${SRCDIR}/build/linux/arm64_static/ggml/src/libggml.a -lstdc++
|
||||
// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
|
||||
// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
|
||||
// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
|
||||
// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
|
||||
// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
|
||||
// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
|
||||
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
|
||||
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
||||
// #include <stdlib.h>
|
||||
// #include "llama.h"
|
||||
import "C"
|
||||
@@ -32,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
|
||||
params.ftype = ftype.Value()
|
||||
|
||||
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@@ -1,11 +1,11 @@
|
||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||
index 73f52435..2b81b4bd 100644
|
||||
index 8fe51971..7113ba64 100644
|
||||
--- a/src/llama.cpp
|
||||
+++ b/src/llama.cpp
|
||||
@@ -5092,16 +5092,7 @@ static void llm_load_vocab(
|
||||
|
||||
// for now, only BPE models have pre-tokenizers
|
||||
@@ -5433,16 +5433,7 @@ static void llm_load_vocab(
|
||||
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
|
||||
vocab.tokenizer_add_space_prefix = false;
|
||||
vocab.tokenizer_clean_spaces = true;
|
||||
- if (tokenizer_pre.empty()) {
|
||||
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
||||
- LLAMA_LOG_WARN("%s: \n", __func__);
|
||||
@@ -20,9 +20,9 @@ index 73f52435..2b81b4bd 100644
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
} else if (
|
||||
tokenizer_pre == "llama3" ||
|
||||
@@ -5164,7 +5155,8 @@ static void llm_load_vocab(
|
||||
tokenizer_pre == "jais") {
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
|
||||
@@ -5526,7 +5517,8 @@ static void llm_load_vocab(
|
||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
|
||||
vocab.tokenizer_clean_spaces = false;
|
||||
} else {
|
||||
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||
|
@@ -1,13 +0,0 @@
|
||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||
index 40d2ec2c..f34eb79a 100644
|
||||
--- a/src/llama.cpp
|
||||
+++ b/src/llama.cpp
|
||||
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
||||
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
|
||||
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
360
llm/patches/09-lora.diff
Normal file
360
llm/patches/09-lora.diff
Normal file
@@ -0,0 +1,360 @@
|
||||
diff --git a/common/common.cpp b/common/common.cpp
|
||||
index dbb724fb..c26fe6ee 100644
|
||||
--- a/common/common.cpp
|
||||
+++ b/common/common.cpp
|
||||
@@ -2087,14 +2087,29 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
||||
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
|
||||
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
|
||||
float lora_scale = std::get<1>(params.lora_adapter[i]);
|
||||
+
|
||||
+ // try to load as gguf
|
||||
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
|
||||
if (adapter == nullptr) {
|
||||
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||
- llama_free(lctx);
|
||||
- llama_free_model(model);
|
||||
- return std::make_tuple(nullptr, nullptr);
|
||||
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
|
||||
+
|
||||
+ // if that fails, try loading as ggla for compatibility
|
||||
+ int err = llama_model_apply_lora_from_file(model,
|
||||
+ lora_adapter.c_str(),
|
||||
+ lora_scale,
|
||||
+ ((i > 0) || params.lora_base.empty())
|
||||
+ ? NULL
|
||||
+ : params.lora_base.c_str(),
|
||||
+ params.n_threads);
|
||||
+ if (err != 0) {
|
||||
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||
+ llama_free(lctx);
|
||||
+ llama_free_model(model);
|
||||
+ return std::make_tuple(nullptr, nullptr);
|
||||
+ }
|
||||
+ } else {
|
||||
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
|
||||
}
|
||||
- llama_lora_adapter_set(lctx, adapter, lora_scale);
|
||||
}
|
||||
|
||||
if (params.ignore_eos) {
|
||||
diff --git a/include/llama.h b/include/llama.h
|
||||
index 93fd77ca..b0fb37a6 100644
|
||||
--- a/include/llama.h
|
||||
+++ b/include/llama.h
|
||||
@@ -1160,6 +1160,20 @@ extern "C" {
|
||||
|
||||
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
|
||||
|
||||
+ // Apply a LoRA adapter to a loaded model
|
||||
+ // path_base_model is the path to a higher quality model to use as a base for
|
||||
+ // the layers modified by the adapter. Can be NULL to use the current loaded model.
|
||||
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
+ // will be applied on top of the previous one
|
||||
+ // Returns 0 on success
|
||||
+ LLAMA_API int32_t llama_model_apply_lora_from_file(
|
||||
+ const struct llama_model * model,
|
||||
+ const char * path_lora,
|
||||
+ float scale,
|
||||
+ const char * path_base_model,
|
||||
+ int32_t n_threads);
|
||||
+
|
||||
+
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||
index 80a0dd0f..9d7b0e17 100644
|
||||
--- a/src/llama.cpp
|
||||
+++ b/src/llama.cpp
|
||||
@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
|
||||
fputs(text, stderr);
|
||||
fflush(stderr);
|
||||
}
|
||||
+
|
||||
+static int llama_apply_lora_from_file_internal(
|
||||
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
|
||||
+) {
|
||||
+ LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
|
||||
+
|
||||
+ const int64_t t_start_lora_us = ggml_time_us();
|
||||
+
|
||||
+ llama_file fin(path_lora, "rb");
|
||||
+
|
||||
+ // verify magic and version
|
||||
+ {
|
||||
+ uint32_t magic = fin.read_u32();
|
||||
+ if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
||||
+ LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ uint32_t format_version = fin.read_u32();
|
||||
+ if (format_version != 1) {
|
||||
+ LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
||||
+ return 1;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ int32_t lora_r = fin.read_u32();
|
||||
+ int32_t lora_alpha = fin.read_u32();
|
||||
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
|
||||
+
|
||||
+ LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||
+
|
||||
+ // load base model
|
||||
+ std::unique_ptr<llama_model_loader> ml;
|
||||
+ if (path_base_model) {
|
||||
+ LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
|
||||
+ ml->init_mappings(/*prefetch*/ false); // no prefetching
|
||||
+ }
|
||||
+
|
||||
+ struct tensor_meta {
|
||||
+ std::string name;
|
||||
+ ggml_type type;
|
||||
+ int32_t ne[2];
|
||||
+ size_t offset;
|
||||
+ };
|
||||
+ std::map<std::string, tensor_meta> tensor_meta_map;
|
||||
+
|
||||
+ // load all tensor meta
|
||||
+ while (true) {
|
||||
+ if (fin.tell() == fin.size) {
|
||||
+ // eof
|
||||
+ break;
|
||||
+ }
|
||||
+
|
||||
+ int32_t n_dims;
|
||||
+ int32_t name_len;
|
||||
+ int32_t ftype;
|
||||
+
|
||||
+ fin.read_raw(&n_dims, sizeof(n_dims));
|
||||
+ fin.read_raw(&name_len, sizeof(name_len));
|
||||
+ fin.read_raw(&ftype, sizeof(ftype));
|
||||
+
|
||||
+ if (n_dims != 1 && n_dims != 2) {
|
||||
+ LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ int32_t ne[2] = { 1, 1 };
|
||||
+ for (int i = 0; i < n_dims; ++i) {
|
||||
+ fin.read_raw(&ne[i], sizeof(ne[i]));
|
||||
+ }
|
||||
+
|
||||
+ std::string name;
|
||||
+ {
|
||||
+ GGML_ASSERT(name_len < GGML_MAX_NAME);
|
||||
+ char buf[GGML_MAX_NAME];
|
||||
+ fin.read_raw(buf, name_len);
|
||||
+ name = std::string(buf, name_len);
|
||||
+ }
|
||||
+
|
||||
+ // check for lora suffix
|
||||
+ std::string lora_suffix;
|
||||
+ if (name.length() > 6) {
|
||||
+ lora_suffix = name.substr(name.length() - 6);
|
||||
+ }
|
||||
+ if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
|
||||
+ LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ // tensor type
|
||||
+ ggml_type wtype;
|
||||
+ switch (ftype) {
|
||||
+ case 0: wtype = GGML_TYPE_F32; break;
|
||||
+ case 1: wtype = GGML_TYPE_F16; break;
|
||||
+ default:
|
||||
+ {
|
||||
+ LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
|
||||
+ __func__, ftype);
|
||||
+ return 1;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // data offset
|
||||
+ size_t offset = fin.tell();
|
||||
+ offset = (offset + 31) & -32;
|
||||
+
|
||||
+ // skip tensor data
|
||||
+ fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
|
||||
+
|
||||
+ tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
|
||||
+ }
|
||||
+
|
||||
+ bool warned = false;
|
||||
+ int n_tensors = 0;
|
||||
+
|
||||
+ // apply
|
||||
+ ggml_backend_t backend_cpu = ggml_backend_cpu_init();
|
||||
+ if (backend_cpu == nullptr) {
|
||||
+ LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
|
||||
+ return 1;
|
||||
+ }
|
||||
+ ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
|
||||
+
|
||||
+ std::vector<no_init<uint8_t>> read_buf;
|
||||
+ for (const auto & it : model.tensors_by_name) {
|
||||
+ const std::string & base_name = it.first;
|
||||
+ ggml_tensor * model_t = it.second;
|
||||
+
|
||||
+ if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
|
||||
+ tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
|
||||
+ continue;
|
||||
+ }
|
||||
+
|
||||
+ tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
|
||||
+ tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
|
||||
+
|
||||
+ ggml_init_params lora_init_params = {
|
||||
+ /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
|
||||
+ /* .mem_buffer */ nullptr,
|
||||
+ /* .no_alloc */ true,
|
||||
+ };
|
||||
+ ggml_context * lora_ctx = ggml_init(lora_init_params);
|
||||
+ if (lora_ctx == nullptr) {
|
||||
+ LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
|
||||
+ ggml_backend_free(backend_cpu);
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ // create tensors
|
||||
+ ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
|
||||
+ ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
|
||||
+ ggml_set_name(loraA, metaA.name.c_str());
|
||||
+ ggml_set_name(loraB, metaB.name.c_str());
|
||||
+
|
||||
+ ggml_tensor * base_t;
|
||||
+ if (ml) {
|
||||
+ if (!ml->get_tensor_meta(base_name.c_str())) {
|
||||
+ LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||
+ return 1;
|
||||
+ }
|
||||
+ base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
|
||||
+ } else {
|
||||
+ base_t = ggml_dup_tensor(lora_ctx, model_t);
|
||||
+ }
|
||||
+ ggml_set_name(base_t, base_name.c_str());
|
||||
+
|
||||
+ // allocate in backend buffer
|
||||
+ ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
||||
+ if (lora_buf == nullptr) {
|
||||
+ LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ // load tensor data
|
||||
+ auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
|
||||
+ read_buf.resize(ggml_nbytes(tensor));
|
||||
+ fin.seek(tensor_meta.offset, SEEK_SET);
|
||||
+ fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
|
||||
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
|
||||
+ };
|
||||
+ load_tensor(metaA, loraA);
|
||||
+ load_tensor(metaB, loraB);
|
||||
+
|
||||
+ // load base model tensor data
|
||||
+ if (ml) {
|
||||
+ ml->load_data_for(base_t);
|
||||
+ } else {
|
||||
+ ggml_backend_tensor_copy(model_t, base_t);
|
||||
+ }
|
||||
+
|
||||
+ if (ggml_is_quantized(base_t->type) && !warned) {
|
||||
+ LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
|
||||
+ "use a f16 or f32 base model with --lora-base\n", __func__);
|
||||
+ warned = true;
|
||||
+ }
|
||||
+
|
||||
+ if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
|
||||
+ LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
|
||||
+ " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
|
||||
+ ggml_free(lora_ctx);
|
||||
+ ggml_backend_buffer_free(lora_buf);
|
||||
+ ggml_backend_free(backend_cpu);
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ auto build_lora_graph = [&]() {
|
||||
+ // w = w + BA*s
|
||||
+ ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
||||
+ ggml_set_name(BA, "BA");
|
||||
+
|
||||
+ if (scaling != 1.0f) {
|
||||
+ BA = ggml_scale(lora_ctx, BA, scaling);
|
||||
+ ggml_set_name(BA, "BA_scaled");
|
||||
+ }
|
||||
+
|
||||
+ ggml_tensor * r;
|
||||
+ r = ggml_add_inplace(lora_ctx, base_t, BA);
|
||||
+ ggml_set_name(r, "r_add");
|
||||
+
|
||||
+ if (base_t->type != model_t->type) {
|
||||
+ // convert the result to the model type
|
||||
+ r = ggml_cast(lora_ctx, r, model_t->type);
|
||||
+ ggml_set_name(r, "r_cast");
|
||||
+ }
|
||||
+
|
||||
+ return r;
|
||||
+ };
|
||||
+
|
||||
+ ggml_cgraph * gf = ggml_new_graph(lora_ctx);
|
||||
+ ggml_tensor * r = build_lora_graph();
|
||||
+ ggml_build_forward_expand(gf, r);
|
||||
+
|
||||
+ ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
||||
+ if (graph_buf == nullptr) {
|
||||
+ LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
|
||||
+ ggml_free(lora_ctx);
|
||||
+ ggml_backend_buffer_free(lora_buf);
|
||||
+ ggml_backend_free(backend_cpu);
|
||||
+ return 1;
|
||||
+ }
|
||||
+
|
||||
+ ggml_backend_graph_compute(backend_cpu, gf);
|
||||
+
|
||||
+ ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
|
||||
+
|
||||
+#if 0
|
||||
+ // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
|
||||
+ //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
|
||||
+
|
||||
+ // sched compute
|
||||
+ ggml_build_forward_expand(gf, build_graph());
|
||||
+ ggml_backend_sched_init_measure(sched, gf);
|
||||
+
|
||||
+ // create the graph again, since the previous one was destroyed by the measure
|
||||
+ ggml_graph_clear(gf);
|
||||
+ ggml_build_forward_expand(gf, build_graph());
|
||||
+ ggml_backend_sched_graph_compute(sched, gf);
|
||||
+ ggml_backend_sched_free(sched);
|
||||
+#endif
|
||||
+
|
||||
+ ggml_backend_buffer_free(lora_buf);
|
||||
+ ggml_backend_buffer_free(graph_buf);
|
||||
+ ggml_free(lora_ctx);
|
||||
+
|
||||
+ n_tensors++;
|
||||
+ if (n_tensors % 4 == 0) {
|
||||
+ LLAMA_LOG_INFO(".");
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ ggml_backend_free(backend_cpu);
|
||||
+
|
||||
+ const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||
+ LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||
+
|
||||
+ return 0;
|
||||
+}
|
||||
+
|
||||
+int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
|
||||
+ try {
|
||||
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
|
||||
+ } catch (const std::exception & err) {
|
||||
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
|
||||
+ return 1;
|
||||
+ }
|
||||
+}
|
||||
\ No newline at end of file
|
@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
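Note: callers now pass every input string in a single Embed call instead of looping over Embedding. A hedged sketch of the calling pattern against a pared-down stand-in for the interface (the fake runner exists only so the example runs):

```go
package main

import (
	"context"
	"fmt"
)

// embedder reproduces just the new batched method from the interface above.
type embedder interface {
	Embed(ctx context.Context, input []string) ([][]float32, error)
}

// fakeRunner returns fixed vectors so the example needs no real model.
type fakeRunner struct{}

func (fakeRunner) Embed(_ context.Context, input []string) ([][]float32, error) {
	out := make([][]float32, len(input))
	for i := range input {
		out[i] = []float32{0.1, 0.2}
	}
	return out, nil
}

func main() {
	var runner embedder = fakeRunner{}
	vecs, err := runner.Embed(context.Background(), []string{"why is the sky blue?", "why is the grass green?"})
	fmt.Println(len(vecs), err) // 2 <nil>
}
```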
@@ -88,6 +88,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
var estimate MemoryEstimate
|
||||
var systemTotalMemory uint64
|
||||
var systemFreeMemory uint64
|
||||
var systemSwapFreeMemory uint64
|
||||
|
||||
systemMemInfo, err := gpu.GetCPUMem()
|
||||
if err != nil {
|
||||
@@ -95,7 +96,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
} else {
|
||||
systemTotalMemory = systemMemInfo.TotalMemory
|
||||
systemFreeMemory = systemMemInfo.FreeMemory
|
||||
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
|
||||
systemSwapFreeMemory = systemMemInfo.FreeSwap
|
||||
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||
}
|
||||
|
||||
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
||||
@@ -122,6 +124,16 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
}

// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
}
}

estimate.log()

// Loop through potential servers
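Note: the new guard is plain arithmetic: the portion of the model that cannot be offloaded to VRAM must fit in free RAM plus free swap. A hedged restatement in isolation (the function name and sample sizes are invented):

```go
package main

import "fmt"

// willFitOnLinux restates the check above; all sizes are in bytes.
func willFitOnLinux(totalSize, vramSize, freeMem, freeSwap uint64) bool {
	required := totalSize - vramSize
	return required <= freeMem+freeSwap
}

func main() {
	// 8 GiB model, 6 GiB offloaded to VRAM, 1.5 GiB free RAM, 1 GiB free swap:
	// 2 GiB is required and 2.5 GiB is available, so the load is allowed.
	fmt.Println(willFitOnLinux(8<<30, 6<<30, 1536<<20, 1<<30)) // true
}
```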
@@ -254,10 +266,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
if estimate.TensorSplit != "" {
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
for i := range len(servers) {
|
||||
dir := availableServers[servers[i]]
|
||||
if dir == "" {
|
||||
@@ -377,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
filteredEnv := []string{}
|
||||
for _, ev := range s.cmd.Env {
|
||||
if strings.HasPrefix(ev, "CUDA_") ||
|
||||
strings.HasPrefix(ev, "ROCR_") ||
|
||||
strings.HasPrefix(ev, "ROCM_") ||
|
||||
strings.HasPrefix(ev, "HIP_") ||
|
||||
strings.HasPrefix(ev, "GPU_") ||
|
||||
strings.HasPrefix(ev, "HSA_") ||
|
||||
strings.HasPrefix(ev, "GGML_") ||
|
||||
strings.HasPrefix(ev, "PATH=") ||
|
||||
@@ -407,7 +417,17 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
|
||||
// reap subprocess when it exits
|
||||
go func() {
|
||||
s.done <- s.cmd.Wait()
|
||||
err := s.cmd.Wait()
|
||||
// Favor a more detailed message over the process exit status
|
||||
if err != nil && s.status != nil && s.status.LastErrMsg != "" {
|
||||
slog.Debug("llama runner terminated", "error", err)
|
||||
if strings.Contains(s.status.LastErrMsg, "unknown model") {
|
||||
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
|
||||
}
|
||||
s.done <- fmt.Errorf(s.status.LastErrMsg)
|
||||
} else {
|
||||
s.done <- err
|
||||
}
|
||||
}()
|
||||
|
||||
return s, nil
|
||||
@@ -570,14 +590,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||
slog.Warn("client connection closed before server finished loading, aborting load")
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||
case err := <-s.done:
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
if strings.Contains(msg, "unknown model") {
|
||||
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
|
||||
}
|
||||
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||
return fmt.Errorf("llama runner process has terminated: %w", err)
|
||||
default:
|
||||
}
|
||||
if time.Now().After(stallTimer) {
|
||||
@@ -679,7 +692,7 @@ type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
Options *api.Options
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
@@ -699,10 +712,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
defer s.sem.Release(1)
|
||||
|
||||
// only allow maximum 10 "context shifts" to avoid infinite generation
|
||||
// put an upper limit on num_predict to avoid the model running on forever
|
||||
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
||||
req.Options.NumPredict = 10 * s.options.NumCtx
|
||||
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
|
||||
}
|
||||
|
||||
request := map[string]any{
|
||||
@@ -860,15 +872,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}

type EmbeddingRequest struct {
Content string `json:"content"`
type EmbedRequest struct {
Content []string `json:"content"`
}

type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
}

func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -883,7 +895,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}

data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@@ -910,7 +922,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("%s", body)
}

var embedding EmbeddingResponse
var embedding EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
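Note: the request body sent to the runner changes shape accordingly: "content" is now an array of strings. A small hedged illustration of the wire format (the struct is copied locally so the example stands alone):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// embedRequest mirrors the EmbedRequest struct above.
type embedRequest struct {
	Content []string `json:"content"`
}

func main() {
	b, _ := json.Marshal(embedRequest{Content: []string{"why is the sky blue?", "why is the grass green?"}})
	fmt.Println(string(b)) // {"content":["why is the sky blue?","why is the grass green?"]}
}
```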
258
openai/openai.go
@@ -3,11 +3,14 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -27,8 +30,9 @@ type ErrorResponse struct {
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type Choice struct {
|
||||
@@ -59,6 +63,11 @@ type ResponseFormat struct {
Type string `json:"type"`
}

type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}

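Note: the OpenAI-style "input" field may arrive as a single string or as an array. A hedged sketch of how such a value can be normalized to []string before being handed to the native embed API (the helper is illustrative, not the middleware's actual code):

```go
package main

import "fmt"

// normalizeInput collapses a string or []any of strings into []string.
func normalizeInput(input any) ([]string, error) {
	switch v := input.(type) {
	case string:
		return []string{v}, nil
	case []any:
		out := make([]string, 0, len(v))
		for _, item := range v {
			s, ok := item.(string)
			if !ok {
				return nil, fmt.Errorf("invalid input element type %T", item)
			}
			out = append(out, s)
		}
		return out, nil
	default:
		return nil, fmt.Errorf("invalid input type %T", input)
	}
}

func main() {
	fmt.Println(normalizeInput("why is the sky blue?"))
	fmt.Println(normalizeInput([]any{"a", "b"}))
}
```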
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
@@ -71,6 +80,7 @@ type ChatCompletionRequest struct {
|
||||
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
||||
TopP *float64 `json:"top_p"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||
Tools []api.Tool `json:"tools"`
|
||||
}
|
||||
|
||||
type ChatCompletion struct {
|
||||
@@ -104,6 +114,7 @@ type CompletionRequest struct {
|
||||
Stream bool `json:"stream"`
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Suffix string `json:"suffix"`
|
||||
}
|
||||
|
||||
type Completion struct {
|
||||
@@ -125,6 +136,15 @@ type CompletionChunk struct {
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
@@ -132,11 +152,23 @@ type Model struct {
OwnedBy string `json:"owned_by"`
}

type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}

type ListCompletion struct {
Object string `json:"object"`
Data []Model `json:"data"`
}

type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}

func NewError(code int, message string) ErrorResponse {
|
||||
var etype string
|
||||
switch code {
|
||||
@@ -151,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
|
||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||
}
|
||||
|
||||
func toolCallId() string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, 8)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return "call_" + strings.ToLower(string(b))
|
||||
}
|
||||
|
||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
|
||||
for i, tc := range r.Message.ToolCalls {
|
||||
toolCalls[i].ID = toolCallId()
|
||||
toolCalls[i].Type = "function"
|
||||
toolCalls[i].Function.Name = tc.Function.Name
|
||||
|
||||
args, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
slog.Error("could not marshall function arguments to json", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
toolCalls[i].Function.Arguments = string(args)
|
||||
}
|
||||
|
||||
return ChatCompletion{
|
||||
Id: id,
|
||||
Object: "chat.completion",
|
||||
@@ -160,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
SystemFingerprint: "fp_ollama",
|
||||
Choices: []Choice{{
|
||||
Index: 0,
|
||||
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
|
||||
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
|
||||
FinishReason: func(reason string) *string {
|
||||
if len(reason) > 0 {
|
||||
return &reason
|
||||
@@ -169,7 +225,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: Usage{
|
||||
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
@@ -215,7 +270,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: Usage{
|
||||
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
@@ -260,6 +314,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
}
}

func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}

return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}

return EmbeddingList{}
}

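Note: for reference, a hedged sketch of the OpenAI-format JSON this conversion ends up producing for a two-vector response (local struct copies, invented values):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

type embeddingList struct {
	Object string      `json:"object"`
	Data   []embedding `json:"data"`
	Model  string      `json:"model"`
}

func main() {
	out := embeddingList{
		Object: "list",
		Model:  "all-minilm",
		Data: []embedding{
			{Object: "embedding", Embedding: []float32{0.01, 0.02}, Index: 0},
			{Object: "embedding", Embedding: []float32{-0.03, 0.04}, Index: 1},
		},
	}
	b, _ := json.MarshalIndent(out, "", "  ")
	fmt.Println(string(b))
}
```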
func toModel(r api.ShowResponse, m string) Model {
|
||||
return Model{
|
||||
Id: m,
|
||||
@@ -269,10 +344,77 @@ func toModel(r api.ShowResponse, m string) Model {
|
||||
}
|
||||
}
|
||||
|
||||
func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
for _, msg := range r.Messages {
|
||||
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
||||
switch content := msg.Content.(type) {
|
||||
case string:
|
||||
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
||||
case []any:
|
||||
for _, c := range content {
|
||||
data, ok := c.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
switch data["type"] {
|
||||
case "text":
|
||||
text, ok := data["text"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
||||
case "image_url":
|
||||
var url string
|
||||
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
||||
if url, ok = urlMap["url"].(string); !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
} else {
|
||||
if url, ok = data["image_url"].(string); !ok {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
}
|
||||
|
||||
types := []string{"jpeg", "jpg", "png"}
|
||||
valid := false
|
||||
for _, t := range types {
|
||||
prefix := "data:image/" + t + ";base64,"
|
||||
if strings.HasPrefix(url, prefix) {
|
||||
url = strings.TrimPrefix(url, prefix)
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("invalid image input")
|
||||
}
|
||||
|
||||
img, err := base64.StdEncoding.DecodeString(url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
}
|
||||
default:
|
||||
if msg.ToolCalls == nil {
|
||||
return nil, fmt.Errorf("invalid message content type: %T", content)
|
||||
}
|
||||
|
||||
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
||||
for i, tc := range msg.ToolCalls {
|
||||
toolCalls[i].Function.Name = tc.Function.Name
|
||||
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid tool call arguments")
|
||||
}
|
||||
}
|
||||
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
||||
}
|
||||
}
|
||||
|
||||
options := make(map[string]interface{})
|
||||
@@ -323,13 +465,14 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||
format = "json"
|
||||
}
|
||||
|
||||
return api.ChatRequest{
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
}
|
||||
Tools: r.Tools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
@@ -338,12 +481,16 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
switch stop := r.Stop.(type) {
|
||||
case string:
|
||||
options["stop"] = []string{stop}
|
||||
case []string:
|
||||
options["stop"] = stop
|
||||
default:
|
||||
if r.Stop != nil {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
|
||||
case []any:
|
||||
var stops []string
|
||||
for _, s := range stop {
|
||||
if str, ok := s.(string); ok {
|
||||
stops = append(stops, str)
|
||||
} else {
|
||||
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
||||
}
|
||||
}
|
||||
options["stop"] = stops
|
||||
}
|
||||
|
||||
if r.MaxTokens != nil {
|
||||
@@ -375,6 +522,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
Prompt: r.Prompt,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Suffix: r.Suffix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -403,6 +551,11 @@ type RetrieveWriter struct {
|
||||
model string
|
||||
}
|
||||
|
||||
type EmbedWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
@@ -568,6 +721,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
var embedResponse api.EmbedResponse
|
||||
err := json.Unmarshal(data, &embedResponse)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
@@ -631,6 +811,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Input == "" {
|
||||
req.Input = []string{""}
|
||||
}
|
||||
|
||||
if req.Input == nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
@@ -652,7 +873,14 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
||||
|
||||
chatReq, err := fromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
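Editor's note: a minimal illustrative sketch, not part of the commits above, of the OpenAI-style request body that the new fromChatRequest accepts (text and image_url content parts). The model name and base64 payload are placeholders.

// Sketch: the message shape accepted by fromChatRequest above.
// image_url may be a string or an object with a "url" field; only
// data:image/{jpeg,jpg,png};base64, payloads are accepted.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	body := map[string]any{
		"model": "llava", // placeholder model name
		"messages": []map[string]any{
			{
				"role": "user",
				"content": []map[string]any{
					{"type": "text", "text": "What is in this image?"},
					{"type": "image_url", "image_url": map[string]string{"url": "data:image/png;base64,iVBORw0KGgo="}}, // placeholder payload
				},
			},
		},
	}

	b, _ := json.MarshalIndent(body, "", "  ")
	fmt.Println(string(b))
}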
@@ -2,8 +2,8 @@ package openai

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -16,7 +16,387 @@ import (
"github.com/stretchr/testify/assert"
)

func TestMiddleware(t *testing.T) {
const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image

func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}

func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next()
}
}

func TestChatMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
}

var capturedRequest *api.ChatRequest

testCases := []testCase{
{
Name: "chat handler",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}

if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}

if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
},
},
{
Name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}

if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}

if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}

img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])

if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}

if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
},
{
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}

if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}

if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}

if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}

if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
},
}

endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}

gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)

tc.Setup(t, req)

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

tc.Expected(t, capturedRequest, resp)

capturedRequest = nil
})
}
}

func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
}

var capturedRequest *api.GenerateRequest

testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}

if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}

stopTokens, ok := req.Options["stop"].([]any)

if !ok {
t.Fatalf("expected stop tokens to be a list")
}

if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}

if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}

if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
},
},
}

endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}

gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)

tc.Setup(t, req)

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

tc.Expected(t, capturedRequest, resp)

capturedRequest = nil
})
}
}

func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
}

var capturedRequest *api.EmbedRequest

testCases := []testCase{
{
Name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}

if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)

if !ok {
t.Fatalf("expected input to be a list")
}

if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}

if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}

if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}

if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
},
},
}

endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}

gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)

tc.Setup(t, req)

resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

tc.Expected(t, capturedRequest, resp)

capturedRequest = nil
})
}
}

func TestMiddlewareResponses(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -29,188 +409,6 @@ func TestMiddleware(t *testing.T) {
}

testCases := []testCase{
{
Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
TestPath: "/api/chat",
Handler: ChatMiddleware,
Endpoint: func(c *gin.Context) {
var chatReq api.ChatRequest
if err := c.ShouldBindJSON(&chatReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}

userMessage := chatReq.Messages[0].Content
var assistantMessage string

switch userMessage {
case "Hello":
assistantMessage = "Hello!"
default:
assistantMessage = "I'm not sure how to respond to that."
}

c.JSON(http.StatusOK, api.ChatResponse{
Message: api.Message{
Role: "assistant",
Content: assistantMessage,
},
})
},
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}

bodyBytes, _ := json.Marshal(body)

req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)

var chatResp ChatCompletion
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
t.Fatal(err)
}

if chatResp.Object != "chat.completion" {
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
}

if chatResp.Choices[0].Message.Content != "Hello!" {
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
}
},
},
{
Name: "completions handler",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.GenerateResponse{
Response: "Hello!",
})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}

bodyBytes, _ := json.Marshal(body)

req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}

if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}

if completionResp.Choices[0].Text != "Hello!" {
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with params",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
var generateReq api.GenerateRequest
if err := c.ShouldBindJSON(&generateReq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
return
}

temperature := generateReq.Options["temperature"].(float64)
var assistantMessage string

switch temperature {
case 1.6:
assistantMessage = "Received temperature of 1.6"
default:
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
}

c.JSON(http.StatusOK, api.GenerateResponse{
Response: assistantMessage,
})
},
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
}

bodyBytes, _ := json.Marshal(body)

req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var completionResp Completion
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
t.Fatal(err)
}

if completionResp.Object != "text_completion" {
t.Fatalf("expected text_completion, got %s", completionResp.Object)
}

if completionResp.Choices[0].Text != "Received temperature of 1.6" {
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
}
},
},
{
Name: "completions handler with error",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}

bodyBytes, _ := json.Marshal(body)

req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}

if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{
Name: "list handler",
Method: http.MethodGet,
@@ -227,8 +425,6 @@ func TestMiddleware(t *testing.T) {
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)

var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
@@ -292,6 +488,8 @@ func TestMiddleware(t *testing.T) {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)

assert.Equal(t, http.StatusOK, resp.Code)

tc.Expected(t, resp)
})
}
@@ -107,9 +107,12 @@ function gatherDependencies() {

# TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}


cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"math"
"math/rand/v2"
"net/http"
"net/url"
"os"
@@ -141,6 +142,32 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
b.err = b.run(ctx, requestURL, opts)
}

func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
var n int
return func(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}

n++

// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}
}

func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -153,6 +180,52 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis

_ = file.Truncate(b.Total)

directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

backoff := newBackoff(10 * time.Second)
for {
// shallow clone opts to be used in the closure
// without affecting the outer opts.
newOpts := new(registryOptions)
*newOpts = *opts

newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errors.New("maximum redirects exceeded (10) for directURL")
}

// if the hostname is the same, allow the redirect
if req.URL.Hostname() == requestURL.Hostname() {
return nil
}

// stop at the first redirect that is not
// the same hostname as the original
// request.
return http.ErrUseLastResponse
}

resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
if err != nil {
slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
if err := backoff(ctx); err != nil {
return nil, err
}
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusTemporaryRedirect {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
return resp.Location()
}
}()
if err != nil {
return err
}

g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
for i := range b.Parts {
@@ -165,7 +238,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
var err error
for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, requestURL, w, part, opts)
err = b.downloadChunk(inner, directURL, w, part, opts)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
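Editor's note: a minimal sketch, not part of the commits above, of the base delays the n^2 backoff in the hunk above produces before the 0.5-1.5x jitter is applied (10ms, 40ms, 90ms, ..., capped at maxBackoff).

// Sketch: base delays of the quadratic backoff shown above.
package main

import (
	"fmt"
	"time"
)

func main() {
	maxBackoff := 10 * time.Second
	for n := 1; n <= 5; n++ {
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		fmt.Printf("attempt %d: base delay %v\n", n, d) // 10ms, 40ms, 90ms, 160ms, 250ms
	}
}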
@@ -34,15 +34,28 @@ import (
"github.com/ollama/ollama/version"
)

var (
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)

type Capability string

const CapabilityCompletion = Capability("completion")
const (
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)

type registryOptions struct {
Insecure bool
Username string
Password string
Token string

CheckRedirect func(req *http.Request, via []*http.Request) error
}

type Model struct {
@@ -62,7 +75,10 @@ type Model struct {
Template *template.Template
}

func (m *Model) Has(caps ...Capability) bool {
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
@@ -81,15 +97,28 @@ func (m *Model) Has(caps ...Capability) bool {
}

if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
return false
errs = append(errs, errCapabilityCompletion)
}
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
}
default:
slog.Error("unknown capability", "capability", cap)
return false
return fmt.Errorf("unknown capability: %s", cap)
}
}

return true
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
}

return nil
}

func (m *Model) String() string {
@@ -465,6 +494,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
if c.Name == "template" {
if _, err := template.Parse(c.Args); err != nil {
return fmt.Errorf("%w: %s", errBadTemplate, err)
}
}

if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
@@ -1098,7 +1133,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.ContentLength = contentLength
}

resp, err := http.DefaultClient.Do(req)
resp, err := (&http.Client{
CheckRedirect: regOpts.CheckRedirect,
}).Do(req)
if err != nil {
return nil, err
}
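Editor's note: a self-contained sketch, not part of the commits above, of the error-wrapping pattern the new CheckCapabilities uses, where each missing capability is joined and the result is wrapped under the "does not support" sentinel so callers can test for it with errors.Is.

// Sketch: the errors.Join + %w-wrapping pattern shown above, reproduced standalone.
package main

import (
	"errors"
	"fmt"
)

var (
	errCapabilities     = errors.New("does not support")
	errCapabilityTools  = errors.New("tools")
	errCapabilityInsert = errors.New("insert")
)

func main() {
	errs := []error{errCapabilityTools, errCapabilityInsert}
	err := fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
	fmt.Println(err)                             // "does not support tools\ninsert"
	fmt.Println(errors.Is(err, errCapabilities)) // true
}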
134
server/model.go
@@ -4,6 +4,7 @@ import (
"archive/zip"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -11,6 +12,9 @@ import (
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"text/template/parse"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
@@ -259,13 +263,27 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err)
} else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}

tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{tmpl, nil})
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})

if t.Parameters != nil {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
return nil, err
}

layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}

layers = append(layers, &layerGGML{layer, nil})
}
}
}
}
@@ -289,3 +307,113 @@ func detectContentType(r io.Reader) (string, error) {

return "unknown", nil
}

// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}

return false
})

if tmpl == nil {
return nil, false
}

var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return nil, false
}

var kv map[string]any
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false
}

// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}

if name == "" || arguments == "" {
return nil, false
}

var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
slog.Error("parseToolCalls", "error", err)
return nil, false
} else {
offset += int(decoder.InputOffset())

// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}

return all
}
objs = append(objs, collect(obj)...)
}
}

var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}

return toolCalls, len(toolCalls) > 0
}

@@ -3,7 +3,9 @@ package server
import (
"archive/zip"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
@@ -11,7 +13,9 @@ import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)

func createZipFile(t *testing.T, name string) *os.File {
@@ -110,3 +114,123 @@ func TestExtractFromZipFile(t *testing.T) {
})
}
}

func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()

bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}

return bytes.NewBuffer(bts)
}

func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]

The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:

[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
}

var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}

var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}

calls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}

for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}

t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})

t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}

if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}
239
server/prompt.go
@@ -1,217 +1,74 @@
package server

import (
"fmt"
"bytes"
"context"
"log/slog"
"strings"

"text/template/parse"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template"
)

// isResponseNode checks if the node contains .Response
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds {
for _, arg := range cmd.Args {
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
if fieldNode.Ident[0] == "Response" {
return true
}
type tokenizeFunc func(context.Context, string) ([]int, error)

// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// always include the last message
n := len(msgs) - 1
// in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
}
return false
}

// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
break
}
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
return "", nil, err
}
}

if !found {
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}

// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
formatTemplateForResponse(tmpl, generate)

vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}

var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}

return sb.String(), nil
}

func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}

tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}

return len(tokens), err
}

// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string

images []int
tokens int
}

var p prompt

// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}

p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}

var sb strings.Builder
for range msg.Images {
fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId)
imgId += 1
}

sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}

p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}

// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}

// calculate token lengths for each prompt, estimating 768 tokens per images
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
s, err := tokenize(ctx, b.String())
if err != nil {
return "", err
return "", nil, err
}

prompts[i].tokens = tokens + len(prompts[i].images)*768
}

// truncate images and prompts starting from the beginning of the list
// until either one prompt remains or the total tokens fits the context window
// TODO (jmorganca): this doesn't account for the context window room required for the response
for {
var required int
for _, p := range prompts {
required += p.tokens
c := len(s)
if m.ProjectorPaths != nil {
for _, m := range msgs[i:] {
// images are represented as 768 sized embeddings
// TODO: get embedding length from project metadata
c += 768 * len(m.Images)
}
}

required += 1 // for bos token

if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
if c > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
n = i
}

prompt := &prompts[0]

if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}

if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]

if system != "" && prompts[0].System == "" {
prompts[0].System = system

tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}

prompts[0].tokens = tokens + len(prompts[0].images)*768
}

continue
}

// stop truncating if there's only one prompt left
break
}

var sb strings.Builder
for i, p := range prompts {
// last prompt should leave the response unrendered (for completion)
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
if err != nil {
return "", err
}
sb.WriteString(rendered)
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
return "", nil, err
}

return sb.String(), nil
for _, m := range msgs[n:] {
for _, i := range m.Images {
images = append(images, llm.ImageData{
ID: len(images),
Data: i,
})
}
}

return b.String(), images, nil
}
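Editor's note: a standalone sketch, not part of the commits above, of the context-window check chatPrompt performs. The whitespace tokenizer is only a stand-in for the real tokenize function, and the 768-token-per-image estimate mirrors the code above.

// Sketch: does a rendered prompt plus its image embeddings fit in NumCtx?
package main

import (
	"fmt"
	"strings"
)

func fits(prompt string, imageCount, numCtx int) bool {
	tokens := len(strings.Fields(prompt)) // stand-in tokenizer (assumption)
	// images are counted as 768-token embeddings, as in chatPrompt above
	return tokens+imageCount*768 <= numCtx
}

func main() {
	fmt.Println(fits("A test. And a thumping good one at that, I'd wager.", 1, 2048)) // true
	fmt.Println(fits("A test. And a thumping good one at that, I'd wager.", 1, 64))   // false
}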
@@ -1,215 +1,209 @@
package server

import (
"strings"
"bytes"
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)

func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}

got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}

if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}

func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
messages []api.Message
window int
want string
type expect struct {
prompt string
images [][]byte
}

cases := []struct {
name string
limit int
msgs []api.Message
expect
}{
{
name: "simple prompt",
template: "[INST] {{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
name: "messages",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
name: "truncate messages",
limit: 1,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
name: "truncate messages with image",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
},
{
name: "with implicit response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
name: "truncate messages with images",
limit: 64,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
},
{
name: "with conversation",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "What are the potion ingredients?"},
{Role: "assistant", Content: "sugar"},
{Role: "user", Content: "Anything else?"},
name: "messages with images",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
name: "with truncation",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
{Role: "user", Content: "Why is the sky blue?"},
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
name: "message with image tag",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 10,
|
||||
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
|
||||
},
|
||||
{
|
||||
name: "images",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
||||
name: "messages with interleaved images",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "You are a Wizard. [img-0] Hello",
|
||||
},
|
||||
{
|
||||
name: "images truncated",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a Wizard."},
|
||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
||||
name: "truncate message with interleaved images",
|
||||
limit: 1024,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
},
|
||||
window: 1024,
|
||||
want: "You are a Wizard. [img-0] [img-1] Hello",
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
template: "{{ .System }} {{ .Prompt }}",
|
||||
messages: []api.Message{},
|
||||
window: 1024,
|
||||
want: "",
|
||||
name: "message with system prompt",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty prompt",
|
||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: ""},
|
||||
name: "out of order system",
|
||||
limit: 2048,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
window: 1024,
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
encode := func(s string) ([]int, error) {
|
||||
words := strings.Fields(s)
|
||||
return make([]int, len(words)), nil
|
||||
tmpl, err := template.Parse(`
|
||||
{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(tc.template)
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
|
||||
if err != nil {
|
||||
t.Errorf("error = %v", err)
|
||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q, want: %q", got, tc.want)
|
||||
if len(images) != len(tt.images) {
|
||||
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
|
||||
}
|
||||
|
||||
for i := range images {
|
||||
if images[i].ID != i {
|
||||
t.Errorf("expected ID %d, got %d", i, images[i].ID)
|
||||
}
|
||||
|
||||
if !bytes.Equal(images[i].Data, tt.images[i]) {
|
||||
t.Errorf("expected %q, got %q", tt.images[i], images[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
660
server/routes.go
@@ -1,14 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -54,6 +55,9 @@ func init() {
|
||||
gin.SetMode(mode)
|
||||
}
|
||||
|
||||
var errRequired = errors.New("is required")
|
||||
var errBadTemplate = errors.New("template error")
|
||||
|
||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||
opts := api.DefaultOptions()
|
||||
if err := opts.FromMap(model.Options); err != nil {
|
||||
@@ -67,250 +71,220 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func isSupportedImageType(image []byte) bool {
|
||||
contentType := http.DetectContentType(image)
|
||||
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
|
||||
return slices.Contains(allowedTypes, contentType)
|
||||
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
||||
// It returns the allocated runner, model instance, and consolidated options if successful, or an error otherwise.
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
if name == "" {
|
||||
return nil, nil, nil, fmt.Errorf("model %w", errRequired)
|
||||
}
|
||||
|
||||
model, err := GetModel(name)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if err := model.CheckCapabilities(caps...); err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, requestOpts)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-runnerCh:
|
||||
case err = <-errCh:
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
var req api.GenerateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// validate the request
|
||||
switch {
|
||||
case req.Model == "":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
if req.Format != "" && req.Format != "json" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
|
||||
return
|
||||
case len(req.Format) > 0 && req.Format != "json":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
||||
return
|
||||
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
|
||||
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||
return
|
||||
}
|
||||
|
||||
for _, img := range req.Images {
|
||||
if !isSupportedImageType(img) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
||||
return
|
||||
}
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, CapabilityInsert)
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
return
|
||||
} else if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !model.Has(CapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
|
||||
return
|
||||
}
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
handleErrorResponse(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// an empty request loads the model
|
||||
// note: for a short while template was used in lieu
|
||||
// of `raw` mode so we need to check for it too
|
||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
tmpl, err := template.Parse(req.Template)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
images := make([]llm.ImageData, len(req.Images))
|
||||
for i := range req.Images {
|
||||
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
var prompt string
|
||||
switch {
|
||||
case req.Raw:
|
||||
prompt = req.Prompt
|
||||
case req.Prompt != "":
|
||||
if req.Template == "" {
|
||||
tmpl = model.Template
|
||||
prompt := req.Prompt
|
||||
if !req.Raw {
|
||||
tmpl := m.Template
|
||||
if req.Template != "" {
|
||||
tmpl, err = template.Parse(req.Template)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.System == "" {
|
||||
req.System = model.System
|
||||
}
|
||||
|
||||
slog.Debug("generate handler", "prompt", req.Prompt)
|
||||
slog.Debug("generate handler", "template", req.Template)
|
||||
slog.Debug("generate handler", "system", req.System)
|
||||
|
||||
var sb strings.Builder
|
||||
for i := range req.Images {
|
||||
fmt.Fprintf(&sb, "[img-%d] ", i)
|
||||
}
|
||||
|
||||
sb.WriteString(req.Prompt)
|
||||
|
||||
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
var b bytes.Buffer
|
||||
if req.Context != nil {
|
||||
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
|
||||
s, err := r.Detokenize(c.Request.Context(), req.Context)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(prev)
|
||||
b.WriteString(s)
|
||||
}
|
||||
|
||||
sb.WriteString(p)
|
||||
|
||||
prompt = sb.String()
|
||||
}
|
||||
|
||||
slog.Debug("generate handler", "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
var generated strings.Builder
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
// Build up the full response
|
||||
if _, err := generated.WriteString(r.Content); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
var values template.Values
|
||||
if req.Suffix != "" {
|
||||
values.Prompt = prompt
|
||||
values.Suffix = req.Suffix
|
||||
} else {
|
||||
var msgs []api.Message
|
||||
if req.System != "" {
|
||||
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
||||
} else if m.System != "" {
|
||||
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
||||
}
|
||||
|
||||
resp := api.GenerateResponse{
|
||||
for _, i := range images {
|
||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||
}
|
||||
|
||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(&b, values); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
prompt = b.String()
|
||||
}
|
||||
|
||||
slog.Debug("generate request", "prompt", prompt, "images", images)
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
// TODO (jmorganca): avoid building the response twice both here and below
|
||||
var sb strings.Builder
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Done: r.Done,
|
||||
Response: r.Content,
|
||||
DoneReason: r.DoneReason,
|
||||
Response: cr.Content,
|
||||
Done: cr.Done,
|
||||
DoneReason: cr.DoneReason,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
PromptEvalCount: cr.PromptEvalCount,
|
||||
PromptEvalDuration: cr.PromptEvalDuration,
|
||||
EvalCount: cr.EvalCount,
|
||||
EvalDuration: cr.EvalDuration,
|
||||
},
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
resp.TotalDuration = time.Since(checkpointStart)
|
||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
if !req.Raw {
|
||||
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// TODO (jmorganca): encode() should not strip special tokens
|
||||
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
|
||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp.Context = append(req.Context, tokens...)
|
||||
res.Context = append(req.Context, tokens...)
|
||||
}
|
||||
}
|
||||
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
var images []llm.ImageData
|
||||
for i := range req.Images {
|
||||
images = append(images, llm.ImageData{
|
||||
ID: i,
|
||||
Data: req.Images[i],
|
||||
})
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
req := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
// Accumulate responses into the final response
|
||||
var final api.GenerateResponse
|
||||
var r api.GenerateResponse
|
||||
var sb strings.Builder
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.GenerateResponse:
|
||||
sb.WriteString(r.Response)
|
||||
final = r
|
||||
sb.WriteString(t.Response)
|
||||
r = t
|
||||
case gin.H:
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
||||
return
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
final.Response = sb.String()
|
||||
c.JSON(http.StatusOK, final)
|
||||
r.Response = sb.String()
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
var req api.EmbeddingRequest
|
||||
func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
var req api.EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
@@ -321,34 +295,122 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
truncate := true
|
||||
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
case string:
|
||||
if len(i) > 0 {
|
||||
input = append(input, i)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range i {
|
||||
if _, ok := v.(string); !ok {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
input = append(input, v.(string))
|
||||
}
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if len(input) == 0 {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
kvData, err := getKVData(m.ModelPath, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
tokens = tokens[:ctxLen]
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
input[i] = s
|
||||
}
|
||||
embeddings, err := r.Embed(c.Request.Context(), input)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
handleErrorResponse(c, err)
|
||||
for i, e := range embeddings {
|
||||
embeddings[i] = normalize(e)
|
||||
}
|
||||
|
||||
resp := api.EmbedResponse{
|
||||
Model: req.Model,
|
||||
Embeddings: embeddings,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
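// normalize scales vec in place to unit (L2) length; an all-zero vector is returned unchanged.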
func normalize(vec []float32) []float32 {
|
||||
var sum float32
|
||||
for _, v := range vec {
|
||||
sum += v * v
|
||||
}
|
||||
|
||||
norm := float32(0.0)
|
||||
if sum > 0 {
|
||||
norm = float32(1.0 / math.Sqrt(float64(sum)))
|
||||
}
|
||||
|
||||
for i := range vec {
|
||||
vec[i] *= norm
|
||||
}
|
||||
return vec
|
||||
}
|
||||
|
||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
var req api.EmbeddingRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -358,13 +420,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
|
||||
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
||||
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddings[0]))
|
||||
|
||||
for i, v := range embeddings[0] {
|
||||
embedding[i] = float64(v)
|
||||
}
|
||||
|
||||
resp := api.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
}
|
||||
@@ -540,7 +609,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
defer cancel()
|
||||
|
||||
quantization := cmp.Or(r.Quantize, r.Quantization)
|
||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
} else if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -642,16 +713,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
m.System = req.System
|
||||
}
|
||||
|
||||
if req.Template != "" {
|
||||
m.Template, err = template.Parse(req.Template)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
msgs := make([]api.Message, 0)
|
||||
for _, msg := range m.Messages {
|
||||
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
|
||||
msgs := make([]api.Message, len(m.Messages))
|
||||
for i, msg := range m.Messages {
|
||||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||
}
|
||||
|
||||
n := model.ParseName(req.Model)
|
||||
@@ -994,6 +1058,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.POST("/api/pull", s.PullModelHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.POST("/api/create", s.CreateModelHandler)
|
||||
r.POST("/api/push", s.PushModelHandler)
|
||||
@@ -1007,6 +1072,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
// Compatibility endpoints
|
||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
|
||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
|
||||
|
||||
@@ -1133,11 +1199,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
|
||||
return
|
||||
}
|
||||
case gin.H:
|
||||
status, ok := r["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
c.JSON(status, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
|
||||
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
|
||||
return
|
||||
}
|
||||
default:
|
||||
@@ -1214,132 +1284,67 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
|
||||
encode := func(s string) ([]int, error) {
|
||||
return runner.llama.Tokenize(ctx, s)
|
||||
}
|
||||
|
||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return prompt, nil
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
var req api.ChatRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// validate the request
|
||||
switch {
|
||||
case req.Model == "":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
case len(req.Format) > 0 && req.Format != "json":
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
||||
return
|
||||
caps := []Capability{CapabilityCompletion}
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, CapabilityTools)
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
if !model.Has(CapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
handleErrorResponse(c, err)
|
||||
} else if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
}
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
// if the first message is not a system message, then add the model's default system message
|
||||
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
|
||||
req.Messages = append([]api.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: model.System,
|
||||
},
|
||||
}, req.Messages...)
|
||||
}
|
||||
|
||||
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// an empty request loads the model
|
||||
if len(req.Messages) == 0 || prompt == "" {
|
||||
resp := api.ChatResponse{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
if len(req.Messages) == 0 {
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant"},
|
||||
Done: true,
|
||||
DoneReason: "load",
|
||||
Message: api.Message{Role: "assistant"},
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// only send images that are in the prompt
|
||||
var i int
|
||||
var images []llm.ImageData
|
||||
for _, m := range req.Messages {
|
||||
for _, img := range m.Images {
|
||||
if !isSupportedImageType(img) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
|
||||
images = append(images, llm.ImageData{Data: img, ID: i})
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
if req.Messages[0].Role != "system" && m.System != "" {
|
||||
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
||||
}
|
||||
|
||||
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||
|
||||
ch := make(chan any)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
resp := api.ChatResponse{
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
@@ -1354,62 +1359,65 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
resp.TotalDuration = time.Since(checkpointStart)
|
||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}, fn); err != nil {
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
// Accumulate responses into the final response
|
||||
var final api.ChatResponse
|
||||
var resp api.ChatResponse
|
||||
var sb strings.Builder
|
||||
for resp := range ch {
|
||||
switch r := resp.(type) {
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.ChatResponse:
|
||||
sb.WriteString(r.Message.Content)
|
||||
final = r
|
||||
sb.WriteString(t.Message.Content)
|
||||
resp = t
|
||||
case gin.H:
|
||||
if errorMsg, ok := r["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
||||
return
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
final.Message = api.Message{Role: "assistant", Content: sb.String()}
|
||||
c.JSON(http.StatusOK, final)
|
||||
resp.Message.Content = sb.String()
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func handleErrorResponse(c *gin.Context, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
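// handleScheduleError maps scheduler errors to HTTP responses: 400 for capability or missing-name errors, 499 for canceled requests, 503 when the queue is full, 404 for unknown models, and 500 otherwise.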
func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
switch {
|
||||
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, context.Canceled):
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrMaxQueue) {
|
||||
case errors.Is(err, ErrMaxQueue):
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
||||
return
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
@@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
|
||||
}
|
||||
|
||||
func TestCreateFromBin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateFromModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateRemovesLayers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateUnsetsSystem(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateMergeParameters(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateReplacesMessages(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateTemplateSystem(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -477,9 +491,47 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
if string(system) != "Say bye!" {
|
||||
t.Errorf("expected \"Say bye!\", actual %s", system)
|
||||
}
|
||||
|
||||
t.Run("incomplete template", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("template with unclosed if", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("template with undefined function", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, actual %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -526,6 +578,8 @@ func TestCreateLicenses(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateDetectTemplate(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -545,9 +599,10 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
|
||||
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
|
||||
filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
|
||||
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
|
||||
filepath.Join(p, "blobs", "sha256-ea34c57ba5b78b740aafe2aeb74dc6507fc3ad14170b64c26a04fb9e36c88d75"),
|
||||
})
|
||||
})
|
||||
|
||||
|
@@ -8,12 +8,15 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
@@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
714
server/routes_generate_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type mockRunner struct {
|
||||
llm.LlamaServer
|
||||
|
||||
// CompletionRequest is only valid until the next call to Completion
|
||||
llm.CompletionRequest
|
||||
llm.CompletionResponse
|
||||
}
|
||||
|
||||
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
m.CompletionRequest = r
|
||||
fn(m.CompletionResponse)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||
for range strings.Fields(s) {
|
||||
tokens = append(tokens, len(tokens))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
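// newMockServer returns a scheduler newServerFn that always hands back the provided mock runner.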
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
return mock, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateChat(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: gpu.GetGPUInfo,
|
||||
getCpuFn: gpu.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(context.TODO())
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s
|
||||
TEMPLATE """
|
||||
{{- if .System }}System: {{ .System }} {{ end }}
|
||||
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||
`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []llm.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("missing body", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, nil)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing capabilities chat", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "bert",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(0),
|
||||
}, []llm.Tensor{})),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "bert",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var actual api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual.Model != "test" {
|
||||
t.Errorf("expected model test, got %s", actual.Model)
|
||||
}
|
||||
|
||||
if !actual.Done {
|
||||
t.Errorf("expected done true, got false")
|
||||
}
|
||||
|
||||
if actual.DoneReason != "load" {
|
||||
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||
t.Helper()
|
||||
|
||||
var actual api.ChatResponse
|
||||
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual.Model != model {
|
||||
t.Errorf("expected model test, got %s", actual.Model)
|
||||
}
|
||||
|
||||
if !actual.Done {
|
||||
t.Errorf("expected done false, got true")
|
||||
}
|
||||
|
||||
if actual.DoneReason != "stop" {
|
||||
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual.Message, api.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
}); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if actual.PromptEvalCount == 0 {
|
||||
t.Errorf("expected prompt eval count > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.PromptEvalDuration == 0 {
|
||||
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.EvalCount == 0 {
|
||||
t.Errorf("expected eval count > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.EvalDuration == 0 {
|
||||
t.Errorf("expected eval duration > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.LoadDuration == 0 {
|
||||
t.Errorf("expected load duration > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.TotalDuration == 0 {
|
||||
t.Errorf("expected total duration > 0, got 0")
|
||||
}
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
t.Run("messages", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test", "Hi!")
|
||||
})
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "test-system",
|
||||
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("messages with model system", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
||||
})
|
||||
|
||||
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||
t.Run("messages with system", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You can perform magic tricks."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||
})
|
||||
|
||||
t.Run("messages with interleaved system", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
{Role: "assistant", Content: "I can help you with that."},
|
||||
{Role: "system", Content: "You can perform magic tricks."},
|
||||
{Role: "user", Content: "Help me write tests."},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: gpu.GetGPUInfo,
|
||||
getCpuFn: gpu.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(context.TODO())
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s
|
||||
TEMPLATE """
|
||||
{{- if .System }}System: {{ .System }} {{ end }}
|
||||
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||
`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []llm.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("missing body", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, nil)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing capabilities generate", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "bert",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(0),
|
||||
}, []llm.Tensor{})),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "bert",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing capabilities suffix", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "def add(",
|
||||
Suffix: " return c",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("load model", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var actual api.GenerateResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual.Model != "test" {
|
||||
t.Errorf("expected model test, got %s", actual.Model)
|
||||
}
|
||||
|
||||
if !actual.Done {
|
||||
t.Errorf("expected done true, got false")
|
||||
}
|
||||
|
||||
if actual.DoneReason != "load" {
|
||||
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||
t.Helper()
|
||||
|
||||
var actual api.GenerateResponse
|
||||
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if actual.Model != model {
|
||||
t.Errorf("expected model test, got %s", actual.Model)
|
||||
}
|
||||
|
||||
if !actual.Done {
|
||||
t.Errorf("expected done false, got true")
|
||||
}
|
||||
|
||||
if actual.DoneReason != "stop" {
|
||||
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||
}
|
||||
|
||||
if actual.Response != content {
|
||||
t.Errorf("expected response %s, got %s", content, actual.Response)
|
||||
}
|
||||
|
||||
if actual.Context == nil {
|
||||
t.Errorf("expected context not nil")
|
||||
}
|
||||
|
||||
if actual.PromptEvalCount == 0 {
|
||||
t.Errorf("expected prompt eval count > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.PromptEvalDuration == 0 {
|
||||
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.EvalCount == 0 {
|
||||
t.Errorf("expected eval count > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.EvalDuration == 0 {
|
||||
t.Errorf("expected eval duration > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.LoadDuration == 0 {
|
||||
t.Errorf("expected load duration > 0, got 0")
|
||||
}
|
||||
|
||||
if actual.TotalDuration == 0 {
|
||||
t.Errorf("expected total duration > 0, got 0")
|
||||
}
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
t.Run("prompt", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello!",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkGenerateResponse(t, w.Body, "test", "Hi!")
|
||||
})
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "test-system",
|
||||
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("prompt with model system", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-system",
|
||||
Prompt: "Hello!",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
||||
})
|
||||
|
||||
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||
t.Run("prompt with system", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-system",
|
||||
Prompt: "Hello!",
|
||||
System: "You can perform magic tricks.",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||
})
|
||||
|
||||
t.Run("prompt with template", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-system",
|
||||
Prompt: "Help me write tests.",
|
||||
System: "You can perform magic tricks.",
|
||||
Template: `{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||
})
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Model: "test-suffix",
|
||||
Modelfile: `FROM test
|
||||
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||
{{- else }}{{ .Prompt }}
|
||||
{{- end }}"""`,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("prompt with suffix", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-suffix",
|
||||
Prompt: "def add(",
|
||||
Suffix: " return c",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("prompt without suffix", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-suffix",
|
||||
Prompt: "def add(",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("raw", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-system",
|
||||
Prompt: "Help me write tests.",
|
||||
Raw: true,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
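The handler tests above build every request through a createRequest helper that is defined elsewhere in the server test files. A minimal sketch of what such a helper could look like follows; the name and call shape are taken from the tests above, but the exact implementation in the repository may differ.

```go
package server

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
)

// createRequest is assumed to encode body as JSON, run it through the given
// gin handler, and hand back the recorder so tests can inspect status and body.
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
	t.Helper()

	var buf bytes.Buffer
	if body != nil {
		if err := json.NewEncoder(&buf).Encode(body); err != nil {
			t.Fatal(err)
		}
	}

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)
	c.Request, _ = http.NewRequest(http.MethodPost, "/", &buf)
	c.Request.Header.Set("Content-Type", "application/json")

	fn(c)
	return w
}
```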
|
@@ -7,11 +7,14 @@ import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
envconfig.LoadConfig()
|
||||
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -272,6 +273,77 @@ func Test_Routes(t *testing.T) {
|
||||
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Embed Handler Empty Input",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/embed",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
embedReq := api.EmbedRequest{
|
||||
Model: "t-bone",
|
||||
Input: "",
|
||||
}
|
||||
jsonData, err := json.Marshal(embedReq)
|
||||
require.NoError(t, err)
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json; charset=utf-8" {
|
||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var embedResp api.EmbedResponse
|
||||
err = json.Unmarshal(body, &embedResp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if embedResp.Model != "t-bone" {
|
||||
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
|
||||
}
|
||||
|
||||
if embedResp.Embeddings == nil {
|
||||
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
|
||||
}
|
||||
|
||||
if len(embedResp.Embeddings) != 0 {
|
||||
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Embed Handler Invalid Input",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/embed",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
embedReq := api.EmbedRequest{
|
||||
Model: "t-bone",
|
||||
Input: 2,
|
||||
}
|
||||
jsonData, err := json.Marshal(embedReq)
|
||||
require.NoError(t, err)
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "application/json; charset=utf-8" {
|
||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
||||
}
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
@@ -420,3 +492,38 @@ func TestShow(t *testing.T) {
|
||||
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalize(t *testing.T) {
|
||||
type testCase struct {
|
||||
input []float32
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{input: []float32{1}},
|
||||
{input: []float32{0, 1, 2, 3}},
|
||||
{input: []float32{0.1, 0.2, 0.3}},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float32{0, 0, 0}},
|
||||
}
|
||||
|
||||
isNormalized := func(vec []float32) (res bool) {
|
||||
sum := 0.0
|
||||
for _, v := range vec {
|
||||
sum += float64(v * v)
|
||||
}
|
||||
if math.Abs(sum-1) > 1e-6 {
|
||||
return sum == 0
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
normalized := normalize(tc.input)
|
||||
if !isNormalized(normalized) {
|
||||
t.Errorf("Vector %v is not normalized", tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
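TestNormalize only asserts that the returned vector has unit L2 norm, or keeps a zero norm for an all-zero input. The normalize function itself lives in the routes code and is not shown in this diff; a minimal sketch that would satisfy the property checked here, not necessarily the repository's exact implementation, is:

```go
package server

import "math"

// normalize is assumed to scale vec to unit L2 length, leaving an all-zero
// vector unchanged so the zero-norm case in the test still passes.
func normalize(vec []float32) []float32 {
	var sum float64
	for _, v := range vec {
		sum += float64(v) * float64(v)
	}

	norm := math.Sqrt(sum)
	if norm == 0 {
		return vec
	}

	out := make([]float32, len(vec))
	for i, v := range vec {
		out[i] = float32(float64(v) / norm)
	}
	return out
}
```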
|
||||
|
@@ -132,18 +132,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("multimodal models don't support parallel requests yet")
|
||||
}
|
||||
// Keep NumCtx and numParallel in sync
|
||||
if numParallel > 1 {
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
} else if strings.Contains(pending.model.Config.ModelFamily, "bert") {
|
||||
numParallel = runtime.NumCPU()
|
||||
}
|
||||
|
||||
for {
|
||||
cpus := s.getCpuFn()
|
||||
var systemMem gpu.GpuInfo
|
||||
if len(cpus) > 0 {
|
||||
systemMem = cpus[0]
|
||||
}
|
||||
var runnerToExpire *runnerRef
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[pending.model.ModelPath]
|
||||
@@ -197,35 +190,15 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
break
|
||||
}
|
||||
|
||||
// Block attempting to load a model larger than system memory + GPU memory
|
||||
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
|
||||
maxSize := systemMem.FreeMemory
|
||||
for _, gpu := range gpus {
|
||||
if gpu.Library == "cpu" {
|
||||
continue
|
||||
}
|
||||
if loadedCount == 0 {
|
||||
// If no other models are loaded, set the limit based on what's available
|
||||
maxSize += gpu.FreeMemory
|
||||
} else {
|
||||
// Other models could be unloaded, favor total memory for limit
|
||||
maxSize += gpu.TotalMemory
|
||||
}
|
||||
}
|
||||
if estimate.TotalSize > maxSize {
|
||||
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
|
||||
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
|
||||
break
|
||||
}
|
||||
|
||||
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||
// simplifying assumption of defaultParallel when in CPU mode
|
||||
if numParallel <= 0 {
|
||||
numParallel = defaultParallel
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
}
|
||||
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
|
||||
if loadedCount == 0 {
|
||||
slog.Debug("cpu mode with first model, loading")
|
||||
s.loadFn(pending, ggml, gpus, numParallel)
|
||||
|
@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
|
||||
require.Len(t, s.expiredCh, 1)
|
||||
}
|
||||
|
||||
type bundle struct {
|
||||
type reqBundle struct {
|
||||
ctx context.Context //nolint:containedctx
|
||||
ctxDone func()
|
||||
srv *mockLlm
|
||||
@@ -102,13 +102,13 @@ type bundle struct {
|
||||
ggml *llm.GGML
|
||||
}
|
||||
|
||||
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
return scenario.srv, nil
|
||||
}
|
||||
|
||||
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
|
||||
scenario := &bundle{}
|
||||
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
|
||||
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
|
||||
b := &reqBundle{}
|
||||
b.ctx, b.ctxDone = context.WithCancel(ctx)
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||
@@ -135,124 +135,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||
|
||||
fname := f.Name()
|
||||
model := &Model{Name: modelName, ModelPath: fname}
|
||||
scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
||||
b.ggml, err = llm.LoadModel(model.ModelPath, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
scenario.req = &LlmRequest{
|
||||
ctx: scenario.ctx,
|
||||
if duration == nil {
|
||||
duration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||
}
|
||||
b.req = &LlmRequest{
|
||||
ctx: b.ctx,
|
||||
model: model,
|
||||
opts: api.DefaultOptions(),
|
||||
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
|
||||
sessionDuration: duration,
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||
return scenario
|
||||
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestRequests(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
func getGpuFn() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "metal"}
|
||||
g.TotalMemory = 24 * format.GigaByte
|
||||
g.FreeMemory = 12 * format.GigaByte
|
||||
return []gpu.GpuInfo{g}
|
||||
}
|
||||
|
||||
func getCpuFn() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "cpu"}
|
||||
g.TotalMemory = 32 * format.GigaByte
|
||||
g.FreeMemory = 26 * format.GigaByte
|
||||
return []gpu.GpuInfo{g}
|
||||
}
|
||||
|
||||
func TestRequestsSameModelSameRequest(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
// Same model, same request
|
||||
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
||||
scenario1b.req.model = scenario1a.req.model
|
||||
scenario1b.ggml = scenario1a.ggml
|
||||
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
||||
|
||||
// simple reload of same model
|
||||
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
||||
tmpModel := *scenario1a.req.model
|
||||
scenario2a.req.model = &tmpModel
|
||||
scenario2a.ggml = scenario1a.ggml
|
||||
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||
|
||||
// Multiple loaded models
|
||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
||||
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
|
||||
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "metal"}
|
||||
g.TotalMemory = 24 * format.GigaByte
|
||||
g.FreeMemory = 12 * format.GigaByte
|
||||
return []gpu.GpuInfo{g}
|
||||
}
|
||||
s.getCpuFn = func() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "cpu"}
|
||||
g.TotalMemory = 32 * format.GigaByte
|
||||
g.FreeMemory = 26 * format.GigaByte
|
||||
return []gpu.GpuInfo{g}
|
||||
}
|
||||
s.newServerFn = scenario1a.newServer
|
||||
slog.Info("scenario1a")
|
||||
s.pendingReqCh <- scenario1a.req
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getCpuFn = getCpuFn
|
||||
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
|
||||
b.req.model = a.req.model
|
||||
b.ggml = a.ggml
|
||||
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
s.pendingReqCh <- a.req
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
s.Run(ctx)
|
||||
select {
|
||||
case resp := <-scenario1a.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario1a.srv)
|
||||
case resp := <-a.req.successCh:
|
||||
require.Equal(t, resp.llama, a.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario1a.req.errCh)
|
||||
case err := <-scenario1a.req.errCh:
|
||||
require.Empty(t, a.req.errCh)
|
||||
case err := <-a.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// Same runner as first request due to not needing a reload
|
||||
s.newServerFn = scenario1b.newServer
|
||||
slog.Info("scenario1b")
|
||||
s.pendingReqCh <- scenario1b.req
|
||||
s.newServerFn = b.newServer
|
||||
slog.Info("b")
|
||||
s.pendingReqCh <- b.req
|
||||
select {
|
||||
case resp := <-scenario1b.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario1a.srv)
|
||||
case resp := <-b.req.successCh:
|
||||
require.Equal(t, resp.llama, a.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario1b.req.errCh)
|
||||
case err := <-scenario1b.req.errCh:
|
||||
require.Empty(t, b.req.errCh)
|
||||
case err := <-b.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer done()
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getCpuFn = getCpuFn
|
||||
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
|
||||
tmpModel := *a.req.model
|
||||
b.req.model = &tmpModel
|
||||
b.ggml = a.ggml
|
||||
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
s.pendingReqCh <- a.req
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
s.Run(ctx)
|
||||
select {
|
||||
case resp := <-a.req.successCh:
|
||||
require.Equal(t, resp.llama, a.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, a.req.errCh)
|
||||
case err := <-a.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
|
||||
// Trigger a reload
|
||||
s.newServerFn = scenario2a.newServer
|
||||
scenario2a.req.model.AdapterPaths = []string{"new"}
|
||||
slog.Info("scenario2a")
|
||||
s.pendingReqCh <- scenario2a.req
|
||||
s.newServerFn = b.newServer
|
||||
b.req.model.AdapterPaths = []string{"new"}
|
||||
slog.Info("b")
|
||||
s.pendingReqCh <- b.req
|
||||
// finish first two requests, so model can reload
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
scenario1a.ctxDone()
|
||||
scenario1b.ctxDone()
|
||||
a.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario2a.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario2a.srv)
|
||||
case resp := <-b.req.successCh:
|
||||
require.Equal(t, resp.llama, b.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario2a.req.errCh)
|
||||
case err := <-scenario2a.req.errCh:
|
||||
require.Empty(t, b.req.errCh)
|
||||
case err := <-b.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestsMultipleLoadedModels(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer done()
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getCpuFn = getCpuFn
|
||||
|
||||
// Multiple loaded models
|
||||
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
|
||||
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
|
||||
c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
|
||||
|
||||
envconfig.MaxRunners = 1
|
||||
s.newServerFn = scenario3a.newServer
|
||||
slog.Info("scenario3a")
|
||||
s.pendingReqCh <- scenario3a.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
scenario2a.ctxDone()
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
s.pendingReqCh <- a.req
|
||||
s.Run(ctx)
|
||||
select {
|
||||
case resp := <-scenario3a.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3a.srv)
|
||||
case resp := <-a.req.successCh:
|
||||
require.Equal(t, resp.llama, a.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario3a.req.errCh)
|
||||
case err := <-scenario3a.req.errCh:
|
||||
require.Empty(t, a.req.errCh)
|
||||
case err := <-a.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
@@ -262,15 +292,15 @@ func TestRequests(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
envconfig.MaxRunners = 0
|
||||
s.newServerFn = scenario3b.newServer
|
||||
slog.Info("scenario3b")
|
||||
s.pendingReqCh <- scenario3b.req
|
||||
s.newServerFn = b.newServer
|
||||
slog.Info("b")
|
||||
s.pendingReqCh <- b.req
|
||||
select {
|
||||
case resp := <-scenario3b.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3b.srv)
|
||||
case resp := <-b.req.successCh:
|
||||
require.Equal(t, resp.llama, b.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario3b.req.errCh)
|
||||
case err := <-scenario3b.req.errCh:
|
||||
require.Empty(t, b.req.errCh)
|
||||
case err := <-b.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
@@ -280,15 +310,15 @@ func TestRequests(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// This is a CPU load with NumGPU = 0 so it should load
|
||||
s.newServerFn = scenario3c.newServer
|
||||
slog.Info("scenario3c")
|
||||
s.pendingReqCh <- scenario3c.req
|
||||
s.newServerFn = c.newServer
|
||||
slog.Info("c")
|
||||
s.pendingReqCh <- c.req
|
||||
select {
|
||||
case resp := <-scenario3c.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3c.srv)
|
||||
case resp := <-c.req.successCh:
|
||||
require.Equal(t, resp.llama, c.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario3c.req.errCh)
|
||||
case err := <-scenario3c.req.errCh:
|
||||
require.Empty(t, c.req.errCh)
|
||||
case err := <-c.req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
@@ -298,25 +328,25 @@ func TestRequests(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
// Try to load a model that won't fit
|
||||
s.newServerFn = scenario3d.newServer
|
||||
slog.Info("scenario3d")
|
||||
s.newServerFn = d.newServer
|
||||
slog.Info("d")
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 3)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
s.pendingReqCh <- scenario3d.req
|
||||
s.pendingReqCh <- d.req
|
||||
// finish prior request, so new model can load
|
||||
time.Sleep(6 * time.Millisecond)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 2)
|
||||
s.loadedMu.Unlock()
|
||||
scenario3b.ctxDone()
|
||||
b.ctxDone()
|
||||
select {
|
||||
case resp := <-scenario3d.req.successCh:
|
||||
require.Equal(t, resp.llama, scenario3d.srv)
|
||||
case resp := <-d.req.successCh:
|
||||
require.Equal(t, resp.llama, d.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, scenario3d.req.errCh)
|
||||
require.Empty(t, d.req.errCh)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
@@ -329,26 +359,19 @@ func TestGetRunner(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
||||
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
||||
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
||||
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
||||
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
|
||||
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
|
||||
envconfig.MaxQueuedRequests = 1
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "metal"}
|
||||
g.TotalMemory = 24 * format.GigaByte
|
||||
g.FreeMemory = 12 * format.GigaByte
|
||||
return []gpu.GpuInfo{g}
|
||||
}
|
||||
s.newServerFn = scenario1a.newServer
|
||||
slog.Info("scenario1a")
|
||||
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getCpuFn = getCpuFn
|
||||
s.newServerFn = a.newServer
|
||||
slog.Info("a")
|
||||
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
slog.Info("scenario1b")
|
||||
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
||||
slog.Info("b")
|
||||
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
|
||||
require.Len(t, s.pendingReqCh, 1)
|
||||
require.Empty(t, successCh1b)
|
||||
require.Len(t, errCh1b, 1)
|
||||
@@ -357,22 +380,24 @@ func TestGetRunner(t *testing.T) {
|
||||
s.Run(ctx)
|
||||
select {
|
||||
case resp := <-successCh1a:
|
||||
require.Equal(t, resp.llama, scenario1a.srv)
|
||||
require.Equal(t, resp.llama, a.srv)
|
||||
require.Empty(t, s.pendingReqCh)
|
||||
require.Empty(t, errCh1a)
|
||||
case err := <-errCh1a:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
scenario1a.ctxDone()
|
||||
a.ctxDone() // Set "a" model to idle so it can unload
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
scenario1c.req.model.ModelPath = "bad path"
|
||||
slog.Info("scenario1c")
|
||||
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
||||
c.req.model.ModelPath = "bad path"
|
||||
slog.Info("c")
|
||||
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
|
||||
// Starts in pending channel, then should be quickly processed to return an error
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
|
||||
require.Empty(t, successCh1c)
|
||||
s.loadedMu.Lock()
|
||||
require.Empty(t, s.loaded)
|
||||
@@ -380,7 +405,7 @@ func TestGetRunner(t *testing.T) {
|
||||
require.Len(t, errCh1c, 1)
|
||||
err = <-errCh1c
|
||||
require.Contains(t, err.Error(), "bad path")
|
||||
scenario1b.ctxDone()
|
||||
b.ctxDone()
|
||||
}
|
||||
|
||||
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
||||
@@ -389,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
|
||||
defer done()
|
||||
|
||||
// Same model, same request
|
||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = func() gpu.GpuInfoList {
|
||||
g := gpu.GpuInfo{Library: "metal"}
|
||||
@@ -411,6 +436,8 @@ func TestPrematureExpired(t *testing.T) {
|
||||
s.loadedMu.Unlock()
|
||||
slog.Info("sending premature expired event now")
|
||||
s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe
|
||||
case err := <-errCh1a:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
@@ -446,6 +473,8 @@ func TestUseLoadedRunner(t *testing.T) {
|
||||
select {
|
||||
case success := <-req.successCh:
|
||||
require.Equal(t, r1, success)
|
||||
case err := <-req.errCh:
|
||||
t.Fatal(err.Error())
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
@@ -625,8 +654,7 @@ func TestAlreadyCanceled(t *testing.T) {
|
||||
defer done()
|
||||
dctx, done2 := context.WithCancel(ctx)
|
||||
done2()
|
||||
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
||||
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
||||
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
|
||||
s := InitScheduler(ctx)
|
||||
slog.Info("scenario1a")
|
||||
s.pendingReqCh <- scenario1a.req
|
||||
@@ -642,8 +670,8 @@ type mockLlm struct {
|
||||
pingResp error
|
||||
waitResp error
|
||||
completionResp error
|
||||
embeddingResp []float64
|
||||
embeddingRespErr error
|
||||
embedResp [][]float32
|
||||
embedRespErr error
|
||||
tokenizeResp []int
|
||||
tokenizeRespErr error
|
||||
detokenizeResp string
|
||||
@@ -660,8 +688,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
|
||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
return s.completionResp
|
||||
}
|
||||
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||
return s.embeddingResp, s.embeddingRespErr
|
||||
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||
return s.embedResp, s.embedRespErr
|
||||
}
|
||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
return s.tokenizeResp, s.tokenizeRespErr
|
||||
|
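The mockLlm change above tracks a reshaped embedding interface: the single-prompt Embedding method returning []float64 is replaced by a batched Embed method that returns one []float32 vector per input string. A small, self-contained illustration of the new shape, using a hypothetical stand-in type rather than the real llm.LlamaServer, is:

```go
package main

import (
	"context"
	"fmt"
)

// embedder mirrors just the new method from the diff: a batched Embed that
// returns one float32 vector per input string.
type embedder interface {
	Embed(ctx context.Context, input []string) ([][]float32, error)
}

// fakeEmbedder is a hypothetical stand-in, returning canned vectors the same
// way mockLlm now returns embedResp.
type fakeEmbedder struct{ resp [][]float32 }

func (f fakeEmbedder) Embed(_ context.Context, _ []string) ([][]float32, error) {
	return f.resp, nil
}

func main() {
	var e embedder = fakeEmbedder{resp: [][]float32{{0.1, 0.2}, {0.3, 0.4}}}
	vecs, err := e.Embed(context.Background(), []string{"hello", "world"})
	if err != nil {
		panic(err)
	}
	for i, v := range vecs {
		fmt.Printf("input %d -> %d dims\n", i, len(v))
	}
}
```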
server/testdata/tools/command-r-plus.gotmpl (vendored, new file, 67 lines)
@@ -0,0 +1,67 @@
|
||||
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
|
||||
{{- if .Tools }}# Safety Preamble
|
||||
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||
|
||||
# System Preamble
|
||||
## Basic Rules
|
||||
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||
|
||||
{{ if .System }}# User Preamble
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
|
||||
## Available Tools
|
||||
Here is a list of tools that you have available to you:
|
||||
{{- range .Tools }}
|
||||
|
||||
```python
|
||||
def {{ .Function.Name }}(
|
||||
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
|
||||
"""{{ .Function.Description }}
|
||||
|
||||
{{- if .Function.Parameters.Properties }}
|
||||
|
||||
Args:
|
||||
{{- range $name, $property := .Function.Parameters.Properties }}
|
||||
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
"""
|
||||
pass
|
||||
```
|
||||
{{- end }}
|
||||
{{- else if .System }}{{ .System }}
|
||||
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}
|
||||
{{- if eq .Role "system" }}
|
||||
{{- continue }}
|
||||
{{- end }}<|START_OF_TURN_TOKEN|>
|
||||
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}
|
||||
Action: ```json
|
||||
[
|
||||
{{- range .ToolCalls }}
|
||||
{
|
||||
"tool_name": "{{ .Function.Name }}",
|
||||
"parameters": {{ .Function.Arguments }}
|
||||
}
|
||||
{{- end }}
|
||||
]```
|
||||
{{ continue }}
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
|
||||
{{ .Content }}</results>
|
||||
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||
{{- end }}
|
||||
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"tool_name": title of the tool in the specification,
|
||||
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||
}
|
||||
]```
|
||||
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
server/testdata/tools/command-r-plus.out (vendored, new file, 39 lines)
@@ -0,0 +1,39 @@
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
|
||||
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||
|
||||
# System Preamble
|
||||
## Basic Rules
|
||||
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||
|
||||
# User Preamble
|
||||
You are a knowledgable assistant. You can answer questions and perform tasks.
|
||||
|
||||
## Available Tools
|
||||
Here is a list of tools that you have available to you:
|
||||
|
||||
```python
|
||||
def get_current_weather(format: string, location: string, ) -> List[Dict]:
|
||||
"""Get the current weather
|
||||
|
||||
Args:
|
||||
format (string): The temperature unit to use. Infer this from the users location.
|
||||
location (string): The city and state, e.g. San Francisco, CA
|
||||
"""
|
||||
pass
|
||||
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
||||
Action: ```json
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {"format":"celsius","location":"Paris, France"}
|
||||
}
|
||||
]```
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
|
||||
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"tool_name": title of the tool in the specification,
|
||||
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||
}
|
||||
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
server/testdata/tools/firefunction.gotmpl (vendored, new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||
|
||||
If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
Available functions as JSON spec:
|
||||
{{- if .Tools }}
|
||||
{{ .Tools }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- range .Messages }}<|start_header_id|>
|
||||
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
|
||||
{{- end }}<|end_header_id|>
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }} functools[
|
||||
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
|
||||
{{- end }}]
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
server/testdata/tools/firefunction.out (vendored, new file, 17 lines)
@@ -0,0 +1,17 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
You are a knowledgable assistant. You can answer questions and perform tasks.
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
|
||||
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
|
||||
|
||||
If you decide to call functions:
|
||||
* prefix function calls with functools marker (no closing marker required)
|
||||
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
|
||||
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
|
||||
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
Available functions as JSON spec:
|
||||
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
server/testdata/tools/llama3-groq-tool-use.gotmpl (vendored, new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}
|
||||
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": <function-name>,"arguments": <args-dict>}
|
||||
</tool_call>
|
||||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{{- range .Tools }} {{ .Function }}
|
||||
{{- end }} </tools>
|
||||
{{- end }}
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range .Messages }}
|
||||
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
|
||||
|
||||
{{ if eq .Role "user" }}{{ .Content }}
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response>
|
||||
{{- end }}<|eot_id|>
|
||||
{{- end }}
|
||||
{{- end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ else }}
|
||||
{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||
|
||||
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}{{ .Response }}
|
||||
{{- if .Response }}<|eot_id|>
|
||||
{{- end }}
|
server/testdata/tools/llama3-groq-tool-use.out (vendored, new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": <function-name>,"arguments": <args-dict>}
|
||||
</tool_call>
|
||||
|
||||
Here are the available tools:
|
||||
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
|
||||
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
server/testdata/tools/messages.json (vendored, new file, 39 lines)
@@ -0,0 +1,39 @@
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like today in Paris?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": {
|
||||
"location": "Paris, France",
|
||||
"format": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
|
||||
"content": "22"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current temperature in Paris, France is 22 degrees Celsius."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like today in San Francisco and Toronto?"
|
||||
}
|
||||
]
|
server/testdata/tools/mistral.gotmpl (vendored, new file, 15 lines)
@@ -0,0 +1,15 @@
|
||||
{{- range $index, $_ := .Messages }}
|
||||
{{- if eq .Role "user" }}
|
||||
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
|
||||
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
|
||||
|
||||
{{ end }}{{ .Content }}[/INST]
|
||||
{{- else if eq .Role "assistant" }}
|
||||
{{- if .Content }} {{ .Content }}</s>
|
||||
{{- else if .ToolCalls }}[TOOL_CALLS] [
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}]</s>
|
||||
{{- end }}
|
||||
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
|
||||
{{- end }}
|
||||
{{- end }}
|
server/testdata/tools/mistral.out (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?[/INST]
|
server/testdata/tools/tools.json (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@
[
  {
    "type": "function",
    "function": {
      "name": "get_current_weather",
      "description": "Get the current weather",
      "parameters": {
        "type": "object",
        "properties": {
          "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA"
          },
          "format": {
            "type": "string",
            "enum": [
              "celsius",
              "fahrenheit"
            ],
            "description": "The temperature unit to use. Infer this from the users location."
          }
        },
        "required": [
          "location",
          "format"
        ]
      }
    }
  }
]
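As a side note (not part of this diff), a tool definition shaped like tools.json above can be decoded with Go's encoding/json; the Tool struct in this sketch is a hypothetical stand-in rather than the server package's actual type:

```
package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// Tool mirrors the shape of the JSON above; the real server code may use
// different type and field names, so treat this as an illustration only.
type Tool struct {
	Type     string `json:"type"`
	Function struct {
		Name        string          `json:"name"`
		Description string          `json:"description"`
		Parameters  json.RawMessage `json:"parameters"`
	} `json:"function"`
}

func main() {
	data, err := os.ReadFile("server/testdata/tools/tools.json")
	if err != nil {
		panic(err)
	}
	var tools []Tool
	if err := json.Unmarshal(data, &tools); err != nil {
		panic(err)
	}
	fmt.Println(tools[0].Function.Name) // prints: get_current_weather
}
```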
45
server/testdata/tools/xlam.gotmpl
vendored
Normal file
@@ -0,0 +1,45 @@
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]

[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]

[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]

[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]


{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:
40
server/testdata/tools/xlam.out
vendored
Normal file
@@ -0,0 +1,40 @@
You are a knowledgable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]

[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]

[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]

[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]


### Response:
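For context, here is a minimal sketch (not part of this change) of how chat templates like the ones above can be exercised with Go's text/template; the Message and chatValues types and the trimmed-down template string are illustrative assumptions, not the server's actual code:

```
package main

import (
	"os"
	"text/template"
)

// Message and chatValues are illustrative stand-ins for the data the
// real server would pass into a chat template.
type Message struct {
	Role    string
	Content string
}

type chatValues struct {
	System   string
	Messages []Message
}

func main() {
	// A stripped-down user/assistant template in the same style as
	// mistral.gotmpl above (tool handling omitted for brevity).
	src := `{{- range .Messages }}
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST]
{{- else if eq .Role "assistant" }} {{ .Content }}</s>
{{- end }}
{{- end }}`

	tmpl := template.Must(template.New("chat").Parse(src))
	err := tmpl.Execute(os.Stdout, chatValues{
		Messages: []Message{
			{Role: "user", Content: "What's the weather like today in Paris?"},
			{Role: "assistant", Content: "The current temperature in Paris, France is 22 degrees Celsius."},
		},
	})
	if err != nil {
		panic(err)
	}
}
```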
8
template/alfred.json
Normal file
@@ -0,0 +1,8 @@
{
  "stop": [
    "<start_system>",
    "<end_message>",
    "<start_user>",
    "<start_assistant>"
  ]
}
@@ -4,4 +4,5 @@
{{ .Prompt }}

{{ end }}### Response:
{{ .Response }}
{{ .Response }}

6
template/alpaca.json
Normal file
@@ -0,0 +1,6 @@
{
  "stop": [
    "### Instruction:",
    "### Response"
  ]
}
@@ -3,4 +3,4 @@
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{ .Response }}<|im_end|>
6
template/chatml.json
Normal file
@@ -0,0 +1,6 @@
{
  "stop": [
    "<|im_start|>",
    "<|im_end|>"
  ]
}
@@ -2,4 +2,5 @@

{{ end }}{{ if .Prompt }}User: {{ .Prompt }}

{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{ end }}Assistant: {{ .Response }}

8
template/chatqa.json
Normal file
@@ -0,0 +1,8 @@
{
  "stop": [
    "System:",
    "User:",
    "Assistant:",
    "<|begin_of_text|>"
  ]
}
@@ -1,8 +1,10 @@
{{ if .System }} Source: system
{{ if .System }}Source: system

{{ .System }} <step>{{ end }} Source: user
{{ .System }} <step> {{ end }}Source: user

{{ .Prompt }} <step> Source: assistant
{{- if not .Response }}
Destination: user
{{- end }}

{{ .Response }}<step>
{{ .Response }} <step>
7
template/codellama-70b-instruct.json
Normal file
@@ -0,0 +1,7 @@
{
  "stop": [
    "Source:",
    "Destination:",
    "<step>"
  ]
}
@@ -1,3 +1,5 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User:
{{ .Prompt }}
{{ end }}Falcon:
{{ .Response }}

6
template/falcon-instruct.json
Normal file
@@ -0,0 +1,6 @@
{
  "stop": [
    "User:",
    "Assistant:"
  ]
}
@@ -1,4 +1,5 @@
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
{{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
{{ .Response }}<end_of_turn>

6
template/gemma-instruct.json
Normal file
@@ -0,0 +1,6 @@
{
  "stop": [
    "<start_of_turn>",
    "<end_of_turn>"
  ]
}
@@ -1,9 +1,9 @@
{{ if .System }}
System:
{{ if .System }}System:
{{ .System }}

{{ end }}{{ if .Prompt }}Question:
{{ .Prompt }}

{{ end }}Answer:
{{ .Response }}
{{ .Response }}

7
template/granite-instruct.json
Normal file
@@ -0,0 +1,7 @@
{
  "stop": [
    "System:",
    "Question:",
    "Answer:"
  ]
}
@@ -1,3 +1,6 @@
[INST] <<SYS>>{{ .System }}<</SYS>>
[INST] <<SYS>>
{{- if .System }}
{{ .System }}
{{ end }}<</SYS>>

{{ .Prompt }} [/INST] {{ .Response }}
{{ .Prompt }} [/INST] {{ .Response }}</s><s>
8
template/llama2-chat.json
Normal file
@@ -0,0 +1,8 @@
{
  "stop": [
    "[INST]",
    "[/INST]",
    "<<SYS>>",
    "<</SYS>>"
  ]
}
7
template/llama3-instruct.json
Normal file
@@ -0,0 +1,7 @@
{
  "stop": [
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|eot_id|>"
  ]
}
@@ -4,4 +4,5 @@
{{ .Prompt }}

{{ end }}@@ Response
{{ .Response }}
{{ .Response }}

6
template/magicoder.json
Normal file
@@ -0,0 +1,6 @@
{
  "stop": [
    "@@ Instruction",
    "@@ Response"
  ]
}
@@ -1,6 +1,3 @@
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
[INST] {{ if .System }}{{ .System }}

{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
Some files were not shown because too many files have changed in this diff.