Compare commits
57 commits: jyan/parse...roy-embed-

2647a0e443, ec4c35fe99, f5e3939220, ae27d9dcfd, 37096790a7, 997c903884, c8af3c2d96, 455e61170d, 4de1370a9d, bbf8f102ee,
bb46bbcf5e, ac33aa7d37, a6cd8f6169, c78089263a, 3e5ea035d5, 5d604eec5b, db0968f30c, 9b60a038e5, 83a0cb8d88, c0648233f2,
d835368eb8, 5784c05397, f14aa5435d, f8fedbda20, b3e5491e41, cc269ba094, a3c20e3f18, 80ee9b5e47, 5534f2cc6a, d321297d8a,
06e5d74e34, 5d707e6fd5, 283948c83b, 1475eab95f, 20090f3172, 69a2d4ccff, e8b954c646, c57317cbf0, 51b2fd299c, d0634b1596,
43606d6d6a, 70b1010fa5, 84e5721f3a, 319fb1ce03, b255445557, f02f83660c, b23424bb3c, 5fd6988126, 5b82960df8, cc9a252d8c,
d281a6e603, 154f6f45d4, 0d41623b52, c279f96371, ebc529cbb3, 73e2c8f68f, f4408219e9
.github/workflows/release.yaml (vendored, 10 changed lines)

@@ -31,7 +31,7 @@ jobs:
           security set-keychain-settings -lut 3600 build.keychain
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: Build Darwin
         env:
@@ -87,7 +87,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - run: go get ./...
       - run: |
@@ -141,7 +141,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: 'Install ROCm'
         run: |
@@ -218,7 +218,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - name: 'Install CUDA'
         run: |
@@ -306,7 +306,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: "stable"
           cache: true
       - run: go get
       - uses: actions/download-artifact@v4
.github/workflows/test.yaml (vendored, 10 changed lines)

@@ -63,7 +63,7 @@ jobs:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
-          go-version-file: go.mod
+          go-version: "stable"
          cache: true
      - run: go get ./...
      - run: |
@@ -163,7 +163,7 @@ jobs:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
-          go-version-file: go.mod
+          go-version: "stable"
          cache: true
      - name: 'Install ROCm'
        run: |
@@ -200,7 +200,7 @@ jobs:
      - uses: actions/checkout@v4
      - uses: actions/setup-go@v5
        with:
-          go-version-file: go.mod
+          go-version: "stable"
          cache: true
      - name: 'Install CUDA'
        run: |
@@ -255,7 +255,7 @@ jobs:
          submodules: recursive
      - uses: actions/setup-go@v5
        with:
-          go-version-file: go.mod
+          go-version: "stable"
          cache: false
      - run: |
          case ${{ matrix.arch }} in
@@ -297,7 +297,7 @@ jobs:
          submodules: recursive
      - uses: actions/setup-go@v5
        with:
-          go-version-file: go.mod
+          go-version: "stable"
          cache: true
      - run: |
          case ${{ matrix.arch }} in
@@ -1,4 +1,4 @@
-ARG GOLANG_VERSION=1.22.1
+ARG GOLANG_VERSION=1.22.5
 ARG CMAKE_VERSION=3.22.1
 # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
 ARG CUDA_VERSION=11.3.1
@@ -64,7 +64,8 @@ Here are some example models that can be downloaded:
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
 | Solar | 10.7B | 6.1GB | `ollama run solar` |
 
-> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
+> [!NOTE]
+> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
 
 ## Customize a model
 
@@ -295,6 +296,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
 - [AI Studio](https://github.com/MindWorkAI/AI-Studio)
+- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
+- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
 
 ### Terminal
 
api/types.go (78 changed lines)

@@ -101,46 +101,29 @@ type ChatRequest struct {
     KeepAlive *Duration `json:"keep_alive,omitempty"`
 
     // Tools is an optional list of tools the model has access to.
-    Tools []Tool `json:"tools,omitempty"`
+    Tools `json:"tools,omitempty"`
 
     // Options lists model-specific options.
     Options map[string]interface{} `json:"options"`
 }
 
+type Tools []Tool
+
+func (t Tools) String() string {
+    bts, _ := json.Marshal(t)
+    return string(bts)
+}
+
 // Message is a single message in a chat sequence. The message contains the
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
     Role string `json:"role"`
-    Content string `json:"content,omitempty"`
+    Content string `json:"content"`
     Images []ImageData `json:"images,omitempty"`
     ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
-type ToolCall struct {
-    Function struct {
-        Name string `json:"name"`
-        Arguments map[string]any `json:"arguments"`
-    } `json:"function"`
-}
-
-type Tool struct {
-    Type string `json:"type"`
-    Function struct {
-        Name string `json:"name"`
-        Description string `json:"description"`
-        Parameters struct {
-            Type string `json:"type"`
-            Required []string `json:"required"`
-            Properties map[string]struct {
-                Type string `json:"type"`
-                Description string `json:"description"`
-                Enum []string `json:"enum,omitempty"`
-            } `json:"properties"`
-        } `json:"parameters"`
-    } `json:"function"`
-}
-
 func (m *Message) UnmarshalJSON(b []byte) error {
     type Alias Message
     var a Alias
@@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
     return nil
 }
 
+type ToolCall struct {
+    Function ToolCallFunction `json:"function"`
+}
+
+type ToolCallFunction struct {
+    Name string `json:"name"`
+    Arguments ToolCallFunctionArguments `json:"arguments"`
+}
+
+type ToolCallFunctionArguments map[string]any
+
+func (t *ToolCallFunctionArguments) String() string {
+    bts, _ := json.Marshal(t)
+    return string(bts)
+}
+
+type Tool struct {
+    Type string `json:"type"`
+    Function ToolFunction `json:"function"`
+}
+
+type ToolFunction struct {
+    Name string `json:"name"`
+    Description string `json:"description"`
+    Parameters struct {
+        Type string `json:"type"`
+        Required []string `json:"required"`
+        Properties map[string]struct {
+            Type string `json:"type"`
+            Description string `json:"description"`
+            Enum []string `json:"enum,omitempty"`
+        } `json:"properties"`
+    } `json:"parameters"`
+}
+
+func (t *ToolFunction) String() string {
+    bts, _ := json.Marshal(t)
+    return string(bts)
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
@@ -405,9 +428,6 @@ type GenerateResponse struct {
     // Response is the textual response itself.
     Response string `json:"response"`
 
-    // ToolCalls is the list of tools the model wants to call
-    ToolCalls []ToolCall `json:"tool_calls,omitempty"`
-
     // Done specifies if the response is complete.
     Done bool `json:"done"`
 
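The new tool-calling types above are plain Go structs, so they can be exercised directly through the existing `api` client. The sketch below is illustrative only: it assumes the `api.ClientFromEnvironment` constructor and the `Chat` callback signature already present in this package, and the weather tool is a hypothetical definition mirroring the docs example further down.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// A single hypothetical tool, shaped like the new api.Tool struct above.
	var weather api.Tool
	weather.Type = "function"
	weather.Function.Name = "get_current_weather"
	weather.Function.Description = "Get the current weather for a location"
	weather.Function.Parameters.Type = "object"
	weather.Function.Parameters.Required = []string{"location"}

	stream := false // the docs change below requires stream=false when tools are supplied
	req := &api.ChatRequest{
		Model:    "mistral",
		Messages: []api.Message{{Role: "user", Content: "What is the weather today in Paris?"}},
		Tools:    api.Tools{weather},
		Stream:   &stream,
	}

	// With stream=false the callback fires once with the complete response,
	// so any tool calls the model produced are available in one place.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		for _, tc := range resp.Message.ToolCalls {
			fmt.Println(tc.Function.Name, tc.Function.Arguments.String())
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```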
@@ -1344,7 +1344,6 @@ func NewCLI() *cobra.Command {
         envVars["OLLAMA_TMPDIR"],
         envVars["OLLAMA_FLASH_ATTENTION"],
         envVars["OLLAMA_LLM_LIBRARY"],
-        envVars["OLLAMA_MAX_VRAM"],
     })
 default:
     appendEnvDocs(cmd, envs)
@@ -71,6 +71,11 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
         "tokenizer.ggml.unknown_token_id": uint32(0),
     }
 
+    if m.Params.HeadDimension > 0 {
+        kv["llama.attention.key_length"] = uint32(m.Params.HeadDimension)
+        kv["llama.attention.value_length"] = uint32(m.Params.HeadDimension)
+    }
+
     return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
 
docs/api.md (201 changed lines)

@@ -40,6 +40,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
 
 - `model`: (required) the [model name](#model-names)
 - `prompt`: the prompt to generate a response for
+- `suffix`: the text after the model response
 - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
 
 Advanced parameters (optional):
@@ -57,7 +58,8 @@ Advanced parameters (optional):
 
 Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#request-json-mode) below.
 
-> Note: it's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace.
+> [!IMPORTANT]
+> It's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace.
 
 ### Examples
 
@@ -148,8 +150,44 @@ If `stream` is set to `false`, the response will be a single JSON object:
 }
 ```
 
+#### Request (with suffix)
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/generate -d '{
+  "model": "codellama:code",
+  "prompt": "def compute_gcd(a, b):",
+  "suffix": " return result",
+  "options": {
+    "temperature": 0
+  },
+  "stream": false
+}'
+```
+
+##### Response
+
+```json
+{
+  "model": "codellama:code",
+  "created_at": "2024-07-22T20:47:51.147561Z",
+  "response": "\n if a == 0:\n return b\n else:\n return compute_gcd(b % a, a)\n\ndef compute_lcm(a, b):\n result = (a * b) / compute_gcd(a, b)\n",
+  "done": true,
+  "done_reason": "stop",
+  "context": [...],
+  "total_duration": 1162761250,
+  "load_duration": 6683708,
+  "prompt_eval_count": 17,
+  "prompt_eval_duration": 201222000,
+  "eval_count": 63,
+  "eval_duration": 953997000
+}
+```
+
 #### Request (JSON mode)
 
+> [!IMPORTANT]
 > When `format` is set to `json`, the output will always be a well-formed JSON object. It's important to also instruct the model to respond in JSON.
 
 ##### Request
@@ -380,12 +418,14 @@ Generate the next message in a chat with a provided model. This is a streaming e
 
 - `model`: (required) the [model name](#model-names)
 - `messages`: the messages of the chat, this can be used to keep a chat memory
+- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`
 
 The `message` object has the following fields:
 
-- `role`: the role of the message, either `system`, `user` or `assistant`
+- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
 - `content`: the content of the message
 - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
+- `tool_calls` (optional): a list of tools the model wants to use
 
 Advanced parameters (optional):
 
@@ -622,6 +662,79 @@ curl http://localhost:11434/api/chat -d '{
 }
 ```
 
+#### Chat request (with tools)
+
+##### Request
+
+```
+curl http://localhost:11434/api/chat -d '{
+  "model": "mistral",
+  "messages": [
+    {
+      "role": "user",
+      "content": "What is the weather today in Paris?"
+    }
+  ],
+  "stream": false,
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather for a location",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "location": {
+              "type": "string",
+              "description": "The location to get the weather for, e.g. San Francisco, CA"
+            },
+            "format": {
+              "type": "string",
+              "description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
+              "enum": ["celsius", "fahrenheit"]
+            }
+          },
+          "required": ["location", "format"]
+        }
+      }
+    }
+  ]
+}'
+```
+
+##### Response
+
+```json
+{
+  "model": "mistral:7b-instruct-v0.3-q4_K_M",
+  "created_at": "2024-07-22T20:33:28.123648Z",
+  "message": {
+    "role": "assistant",
+    "content": "",
+    "tool_calls": [
+      {
+        "function": {
+          "name": "get_current_weather",
+          "arguments": {
+            "format": "celsius",
+            "location": "Paris, FR"
+          }
+        }
+      }
+    ]
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 885095291,
+  "load_duration": 3753500,
+  "prompt_eval_count": 122,
+  "prompt_eval_duration": 328493000,
+  "eval_count": 33,
+  "eval_duration": 552222000
+}
+```
+
 ## Create a Model
 
 ```shell
@@ -1026,7 +1139,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
 ## Generate Embeddings
 
 ```shell
-POST /api/embeddings
+POST /api/embed
 ```
 
 Generate embeddings from a model
@@ -1034,10 +1147,11 @@ Generate embeddings from a model
 ### Parameters
 
 - `model`: name of model to generate embeddings from
-- `prompt`: text to generate embeddings for
+- `input`: text or list of text to generate embeddings for
 
 Advanced parameters:
 
+- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
 
@@ -1046,9 +1160,9 @@ Advanced parameters:
 #### Request
 
 ```shell
-curl http://localhost:11434/api/embeddings -d '{
+curl http://localhost:11434/api/embed -d '{
   "model": "all-minilm",
-  "prompt": "Here is an article about llamas..."
+  "input": "Why is the sky blue?"
 }'
 ```
 
@@ -1056,10 +1170,35 @@ curl http://localhost:11434/api/embeddings -d '{
 
 ```json
 {
-  "embedding": [
-    0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
-    0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
-  ]
+  "model": "all-minilm",
+  "embeddings": [[
+    0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
+    0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
+  ]]
+}
+```
+
+#### Request (Multiple input)
+
+```shell
+curl http://localhost:11434/api/embed -d '{
+  "model": "all-minilm",
+  "input": ["Why is the sky blue?", "Why is the grass green?"]
+}'
+```
+
+#### Response
+
+```json
+{
+  "model": "all-minilm",
+  "embeddings": [[
+    0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
+    0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
+  ],[
+    -0.0098027075, 0.06042469, 0.025257962, -0.006364387, 0.07272725,
+    0.017194884, 0.09032035, -0.051705178, 0.09951512, 0.09072481
+  ]]
 }
 ```
 
@@ -1106,3 +1245,45 @@ A single JSON object will be returned.
 ]
 }
 ```
+
+## Generate Embedding
+
+> Note: this endpoint has been superseded by `/api/embed`
+
+```shell
+POST /api/embeddings
+```
+
+Generate embeddings from a model
+
+### Parameters
+
+- `model`: name of model to generate embeddings from
+- `prompt`: text to generate embeddings for
+
+Advanced parameters:
+
+- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
+- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
+
+### Examples
+
+#### Request
+
+```shell
+curl http://localhost:11434/api/embeddings -d '{
+  "model": "all-minilm",
+  "prompt": "Here is an article about llamas..."
+}'
+```
+
+#### Response
+
+```json
+{
+  "embedding": [
+    0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
+    0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
+  ]
+}
+```
|||||||
15
docs/gpu.md
15
docs/gpu.md
@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
|
|||||||
|
|
||||||
## AMD Radeon
|
## AMD Radeon
|
||||||
Ollama supports the following AMD GPUs:
|
Ollama supports the following AMD GPUs:
|
||||||
|
|
||||||
|
### Linux Support
|
||||||
| Family | Cards and accelerators |
|
| Family | Cards and accelerators |
|
||||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
|
||||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
|
||||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
|
||||||
|
|
||||||
### Overrides
|
### Windows Support
|
||||||
|
With ROCm v6.1, the following GPUs are supported on Windows.
|
||||||
|
|
||||||
|
| Family | Cards and accelerators |
|
||||||
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
||||||
|
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
|
||||||
|
|
||||||
|
|
||||||
|
### Overrides on Linux
|
||||||
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
|
||||||
some cases you can force the system to try to use a similar LLVM target that is
|
some cases you can force the system to try to use a similar LLVM target that is
|
||||||
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
|
||||||
@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
|
|||||||
server. If you have an unsupported AMD GPU you can experiment using the list of
|
server. If you have an unsupported AMD GPU you can experiment using the list of
|
||||||
supported types below.
|
supported types below.
|
||||||
|
|
||||||
At this time, the known supported GPU types are the following LLVM Targets.
|
At this time, the known supported GPU types on linux are the following LLVM Targets.
|
||||||
This table shows some example GPUs that map to these LLVM targets:
|
This table shows some example GPUs that map to these LLVM targets:
|
||||||
| **LLVM Target** | **An Example GPU** |
|
| **LLVM Target** | **An Example GPU** |
|
||||||
|-----------------|---------------------|
|
|-----------------|---------------------|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
 # Ollama Model File
 
-> Note: `Modelfile` syntax is in development
+> [!NOTE]
+> `Modelfile` syntax is in development
 
 A model file is the blueprint to create and share models with Ollama.
 
@@ -78,8 +78,8 @@ curl http://localhost:11434/v1/chat/completions \
 - [x] Streaming
 - [x] JSON mode
 - [x] Reproducible outputs
+- [x] Tools (streaming support coming soon)
 - [ ] Vision
-- [ ] Function calling
 - [ ] Logprobs
 
 #### Supported request fields
@@ -97,16 +97,12 @@ curl http://localhost:11434/v1/chat/completions \
 - [x] `temperature`
 - [x] `top_p`
 - [x] `max_tokens`
-- [ ] `logit_bias`
-- [ ] `tools`
+- [x] `tools`
 - [ ] `tool_choice`
+- [ ] `logit_bias`
 - [ ] `user`
 - [ ] `n`
 
-#### Notes
-
-- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
-
 ## Models
 
 Before using a model, pull it locally `ollama pull`:
docs/template.md (new file, 173 lines)

@@ -0,0 +1,173 @@

# Template

Ollama provides a powerful templating engine backed by Go's built-in templating engine to construct prompts for your large language model. This feature is a valuable tool to get the most out of your models.

## Basic Template Structure

A basic Go template consists of three main parts:

* **Layout**: The overall structure of the template.
* **Variables**: Placeholders for dynamic data that will be replaced with actual values when the template is rendered.
* **Functions**: Custom functions or logic that can be used to manipulate the template's content.

Here's an example of a simple chat template:

```gotmpl
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
```

In this example, we have:

* A basic messages structure (layout)
* Three variables: `Messages`, `Role`, and `Content` (variables)
* A custom function (action) that iterates over an array of items (`range .Messages`) and displays each item

## Adding templates to your model

By default, models imported into Ollama have a default template of `{{ .Prompt }}`, i.e. user inputs are sent verbatim to the LLM. This is appropriate for text or code completion models but lacks essential markers for chat or instruction models.

Omitting a template in these models puts the responsibility of correctly templating input onto the user. Adding a template allows users to easily get the best results from the model.

To add templates in your model, you'll need to add a `TEMPLATE` command to the Modelfile. Here's an example using Meta's Llama 3.

```dockerfile
FROM llama3

TEMPLATE """{{- if .System }}<|start_header_id|>system<|end_header_id|>

{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>

{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

"""
```

## Variables

`System` (string): system prompt

`Prompt` (string): user prompt

`Response` (string): assistant response

`Suffix` (string): text inserted after the assistant's response

`Messages` (list): list of messages

`Messages[].Role` (string): role which can be one of `system`, `user`, `assistant`, or `tool`

`Messages[].Content` (string): message content

`Messages[].ToolCalls` (list): list of tools the model wants to call

`Messages[].ToolCalls[].Function` (object): function to call

`Messages[].ToolCalls[].Function.Name` (string): function name

`Messages[].ToolCalls[].Function.Arguments` (map): mapping of argument name to argument value

`Tools` (list): list of tools the model can access

`Tools[].Type` (string): schema type. `type` is always `function`

`Tools[].Function` (object): function definition

`Tools[].Function.Name` (string): function name

`Tools[].Function.Description` (string): function description

`Tools[].Function.Parameters` (object): function parameters

`Tools[].Function.Parameters.Type` (string): schema type. `type` is always `object`

`Tools[].Function.Parameters.Required` (list): list of required properties

`Tools[].Function.Parameters.Properties` (map): mapping of property name to property definition

`Tools[].Function.Parameters.Properties[].Type` (string): property type

`Tools[].Function.Parameters.Properties[].Description` (string): property description

`Tools[].Function.Parameters.Properties[].Enum` (list): list of valid values

## Tips and Best Practices

Keep the following tips and best practices in mind when working with Go templates:

* **Be mindful of dot**: Control flow structures like `range` and `with` changes the value `.`
* **Out-of-scope variables**: Use `$.` to reference variables not currently in scope, starting from the root
* **Whitespace control**: Use `-` to trim leading (`{{-`) and trailing (`-}}`) whitespace

## Examples

### Example Messages

#### ChatML

ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.

```gotmpl
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
```

### Example Tools

Tools support can be added to a model by adding a `{{ .Tools }}` node to the template. This feature is useful for models trained to call external tools and can a powerful tool for retrieving real-time data or performing complex tasks.

#### Mistral

Mistral v0.3 and Mixtral 8x22B supports tool calling.

```gotmpl
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}

{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}
```

### Example Fill-in-Middle

Fill-in-middle support can be added to a model by adding a `{{ .Suffix }}` node to the template. This feature is useful for models that are trained to generate text in the middle of user input, such as code completion models.

#### CodeLlama

CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://ollama.com/library/codellama:13b-code) code completion models support fill-in-middle.

```gotmpl
<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
```

> [!NOTE]
> CodeLlama 34B and 70B code completion and all instruct and Python fine-tuned models do not support fill-in-middle.

#### Codestral

Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.

```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```
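Since the templating engine described above is Go's built-in `text/template`, the first chat template in the new document can be tried standalone. The sketch below uses a hypothetical local `message` struct rather than the real `api` types.

```go
package main

import (
	"os"
	"text/template"
)

type message struct {
	Role    string
	Content string
}

func main() {
	// The simple chat template from docs/template.md above.
	tmpl := template.Must(template.New("chat").Parse(
		"{{- range .Messages }}\n{{ .Role }}: {{ .Content }}\n{{- end }}\n"))

	data := struct{ Messages []message }{
		Messages: []message{
			{Role: "system", Content: "You are a helpful assistant."},
			{Role: "user", Content: "Why is the sky blue?"},
		},
	}

	// Prints each message as "role: content", showing how range and dot behave.
	if err := tmpl.Execute(os.Stdout, data); err != nil {
		panic(err)
	}
}
```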
@@ -43,8 +43,6 @@ var (
     MaxRunners int
     // Set via OLLAMA_MAX_QUEUE in the environment
     MaxQueuedRequests int
-    // Set via OLLAMA_MAX_VRAM in the environment
-    MaxVRAM uint64
     // Set via OLLAMA_MODELS in the environment
     ModelsDir string
     // Set via OLLAMA_NOHISTORY in the environment
@@ -89,7 +87,6 @@ func AsMap() map[string]EnvVar {
     "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
     "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
     "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
-    "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
     "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
     "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
     "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
@@ -194,16 +191,6 @@ func LoadConfig() {
 
     TmpDir = clean("OLLAMA_TMPDIR")
 
-    userLimit := clean("OLLAMA_MAX_VRAM")
-    if userLimit != "" {
-        avail, err := strconv.ParseUint(userLimit, 10, 64)
-        if err != nil {
-            slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
-        } else {
-            MaxVRAM = avail
-        }
-    }
-
     LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
 
     if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
@@ -33,9 +33,10 @@ type HipLib struct {
 }
 
 func NewHipLib() (*HipLib, error) {
-    h, err := windows.LoadLibrary("amdhip64.dll")
+    // At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
+    h, err := windows.LoadLibrary("amdhip64_6.dll")
     if err != nil {
-        return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
+        return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
     }
     hl := &HipLib{}
     hl.dll = h
@@ -92,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
             continue
         }
         if gfxOverride == "" {
-            if !slices.Contains[[]string, string](supported, gfx) {
+            // Strip off Target Features when comparing
+            if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
                 slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
                 // TODO - consider discrete markdown just for ROCM troubleshooting?
                 slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
@@ -69,7 +69,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
     reqLimit := len(req)
     iterLimit := 5
 
-    vram := os.Getenv("OLLAMA_MAX_VRAM")
+    vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
     if vram != "" {
         max, err := strconv.ParseUint(vram, 10, 64)
         require.NoError(t, err)
@@ -106,7 +106,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
 
 // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
 func TestMultiModelStress(t *testing.T) {
-    vram := os.Getenv("OLLAMA_MAX_VRAM")
+    vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
     if vram == "" {
         t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
     }
@@ -12,7 +12,7 @@ import (
 
 func TestContextExhaustion(t *testing.T) {
     // Longer needed for small footprint GPUs
-    ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
+    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
     defer cancel()
     // Set up the test data
     req := api.GenerateRequest{
@@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) {
             "num_ctx": 128,
         },
     }
-    GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+    if err := PullIfMissing(ctx, client, req.Model); err != nil {
+        t.Fatalf("PullIfMissing failed: %v", err)
+    }
+    DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
 }
@@ -4,12 +4,45 @@ package integration
 
 import (
     "context"
+    "math"
     "testing"
     "time"
 
     "github.com/ollama/ollama/api"
 )
 
+func floatsEqual32(a, b float32) bool {
+    return math.Abs(float64(a-b)) <= 1e-4
+}
+
+func floatsEqual64(a, b float64) bool {
+    return math.Abs(a-b) <= 1e-4
+}
+
+func TestAllMiniLMEmbeddings(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+    defer cancel()
+
+    req := api.EmbeddingRequest{
+        Model:  "all-minilm",
+        Prompt: "why is the sky blue?",
+    }
+
+    res, err := embeddingTestHelper(ctx, t, req)
+
+    if err != nil {
+        t.Fatalf("error: %v", err)
+    }
+
+    if len(res.Embedding) != 384 {
+        t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
+    }
+
+    if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
+        t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
+    }
+}
+
 func TestAllMiniLMEmbed(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
     defer cancel()
@@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
         t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
     }
 
-    if res.Embeddings[0][0] != 0.010071031 {
-        t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
+    if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
+        t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
     }
 }
 
@@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
         t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
     }
 
-    if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
-        t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
+    if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
+        t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
     }
 }
 
-func TestAllMiniLmEmbedTruncate(t *testing.T) {
+func TestAllMiniLMEmbedTruncate(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
     defer cancel()
 
@@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
     }
 }
 
+func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+    if err := PullIfMissing(ctx, client, req.Model); err != nil {
+        t.Fatalf("failed to pull model %s: %v", req.Model, err)
+    }
+
+    response, err := client.Embeddings(ctx, &req)
+
+    if err != nil {
+        return nil, err
+    }
+
+    return response, nil
+}
+
 func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
     client, _, cleanup := InitServerConnection(ctx, t)
     defer cleanup()
@@ -7,8 +7,8 @@ function amdGPUs {
         return $env:AMDGPU_TARGETS
     }
     # Current supported rocblas list from ROCm v6.1.2 on windows
+    # https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
     $GPU_LIST = @(
-        "gfx906:xnack-"
         "gfx1030"
         "gfx1100"
         "gfx1101"
Submodule llm/llama.cpp updated: a8db2a9ce6...d94c6e0ccb
@@ -1,8 +1,8 @@
 diff --git a/src/llama.cpp b/src/llama.cpp
-index 2b9ace28..172640e2 100644
+index 8fe51971..7113ba64 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
+@@ -5433,16 +5433,7 @@ static void llm_load_vocab(
     if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
         vocab.tokenizer_add_space_prefix = false;
         vocab.tokenizer_clean_spaces = true;
@@ -20,9 +20,9 @@ index 2b9ace28..172640e2 100644
     vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 } else if (
     tokenizer_pre == "llama3" ||
-@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
-    tokenizer_pre == "jais") {
-    vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
+@@ -5526,7 +5517,8 @@ static void llm_load_vocab(
+    vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+    vocab.tokenizer_clean_spaces = false;
 } else {
 - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 + LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
@@ -1,13 +0,0 @@
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 40d2ec2c..f34eb79a 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
-    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
-    cb(kq, "kq", il);
-
--   if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
-+   if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
-        // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-        // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
llm/patches/09-lora.diff (new file, 360 lines)

@@ -0,0 +1,360 @@

diff --git a/common/common.cpp b/common/common.cpp
index dbb724fb..c26fe6ee 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2087,14 +2087,29 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
 const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
 float lora_scale = std::get<1>(params.lora_adapter[i]);
+
+ // try to load as gguf
 auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
 if (adapter == nullptr) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- llama_free(lctx);
- llama_free_model(model);
- return std::make_tuple(nullptr, nullptr);
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
+
+ // if that fails, try loading as ggla for compatibility
+ int err = llama_model_apply_lora_from_file(model,
+ lora_adapter.c_str(),
+ lora_scale,
+ ((i > 0) || params.lora_base.empty())
+ ? NULL
+ : params.lora_base.c_str(),
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+ } else {
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
 }
- llama_lora_adapter_set(lctx, adapter, lora_scale);
 }

 if (params.ignore_eos) {
diff --git a/include/llama.h b/include/llama.h
index 93fd77ca..b0fb37a6 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1160,6 +1160,20 @@ extern "C" {

 LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);

+ // Apply a LoRA adapter to a loaded model
+ // path_base_model is the path to a higher quality model to use as a base for
+ // the layers modified by the adapter. Can be NULL to use the current loaded model.
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+ // will be applied on top of the previous one
+ // Returns 0 on success
+ LLAMA_API int32_t llama_model_apply_lora_from_file(
+ const struct llama_model * model,
+ const char * path_lora,
+ float scale,
+ const char * path_base_model,
+ int32_t n_threads);
+
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/llama.cpp b/src/llama.cpp
index 80a0dd0f..9d7b0e17 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
 fputs(text, stderr);
 fflush(stderr);
 }
+
+static int llama_apply_lora_from_file_internal(
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
+) {
+ LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
+
+ const int64_t t_start_lora_us = ggml_time_us();
+
+ llama_file fin(path_lora, "rb");
+
+ // verify magic and version
+ {
+ uint32_t magic = fin.read_u32();
+ if (magic != LLAMA_FILE_MAGIC_GGLA) {
+ LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
+ return 1;
+ }
+
+ uint32_t format_version = fin.read_u32();
+ if (format_version != 1) {
+ LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
+ return 1;
+ }
+ }
+
+ int32_t lora_r = fin.read_u32();
+ int32_t lora_alpha = fin.read_u32();
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
+
+ LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+
+ // load base model
+ std::unique_ptr<llama_model_loader> ml;
+ if (path_base_model) {
+ LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
+ ml->init_mappings(/*prefetch*/ false); // no prefetching
+ }
+
+ struct tensor_meta {
+ std::string name;
+ ggml_type type;
+ int32_t ne[2];
+ size_t offset;
+ };
+ std::map<std::string, tensor_meta> tensor_meta_map;
+
+ // load all tensor meta
+ while (true) {
+ if (fin.tell() == fin.size) {
+ // eof
+ break;
+ }
+
+ int32_t n_dims;
+ int32_t name_len;
+ int32_t ftype;
+
+ fin.read_raw(&n_dims, sizeof(n_dims));
+ fin.read_raw(&name_len, sizeof(name_len));
+ fin.read_raw(&ftype, sizeof(ftype));
+
+ if (n_dims != 1 && n_dims != 2) {
+ LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
+ return 1;
+ }
+
+ int32_t ne[2] = { 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ fin.read_raw(&ne[i], sizeof(ne[i]));
+ }
+
+ std::string name;
+ {
+ GGML_ASSERT(name_len < GGML_MAX_NAME);
+ char buf[GGML_MAX_NAME];
+ fin.read_raw(buf, name_len);
+ name = std::string(buf, name_len);
+ }
+
+ // check for lora suffix
+ std::string lora_suffix;
+ if (name.length() > 6) {
+ lora_suffix = name.substr(name.length() - 6);
+ }
+ if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
+ LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+ return 1;
+ }
+
+ // tensor type
+ ggml_type wtype;
+ switch (ftype) {
+ case 0: wtype = GGML_TYPE_F32; break;
+ case 1: wtype = GGML_TYPE_F16; break;
+ default:
+ {
+ LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
+ __func__, ftype);
+ return 1;
+ }
+ }
+
+ // data offset
+ size_t offset = fin.tell();
+ offset = (offset + 31) & -32;
+
+ // skip tensor data
+ fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
+
+ tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
+ }
+
+ bool warned = false;
+ int n_tensors = 0;
+
+ // apply
+ ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+ if (backend_cpu == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
+ return 1;
+ }
+ ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
+
+ std::vector<no_init<uint8_t>> read_buf;
+ for (const auto & it : model.tensors_by_name) {
+ const std::string & base_name = it.first;
+ ggml_tensor * model_t = it.second;
+
+ if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
|
||||||
|
+ tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
|
||||||
|
+ continue;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
|
||||||
|
+ tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
|
||||||
|
+
|
||||||
|
+ ggml_init_params lora_init_params = {
|
||||||
|
+ /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
|
||||||
|
+ /* .mem_buffer */ nullptr,
|
||||||
|
+ /* .no_alloc */ true,
|
||||||
|
+ };
|
||||||
|
+ ggml_context * lora_ctx = ggml_init(lora_init_params);
|
||||||
|
+ if (lora_ctx == nullptr) {
|
||||||
|
+ LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
|
||||||
|
+ ggml_backend_free(backend_cpu);
|
||||||
|
+ return 1;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // create tensors
|
||||||
|
+ ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
|
||||||
|
+ ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
|
||||||
|
+ ggml_set_name(loraA, metaA.name.c_str());
|
||||||
|
+ ggml_set_name(loraB, metaB.name.c_str());
|
||||||
|
+
|
||||||
|
+ ggml_tensor * base_t;
|
||||||
|
+ if (ml) {
|
||||||
|
+ if (!ml->get_tensor_meta(base_name.c_str())) {
|
||||||
|
+ LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||||
|
+ return 1;
|
||||||
|
+ }
|
||||||
|
+ base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
|
||||||
|
+ } else {
|
||||||
|
+ base_t = ggml_dup_tensor(lora_ctx, model_t);
|
||||||
|
+ }
|
||||||
|
+ ggml_set_name(base_t, base_name.c_str());
|
||||||
|
+
|
||||||
|
+ // allocate in backend buffer
|
||||||
|
+ ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
||||||
|
+ if (lora_buf == nullptr) {
|
||||||
|
+ LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
|
||||||
|
+ return 1;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // load tensor data
|
||||||
|
+ auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
|
||||||
|
+ read_buf.resize(ggml_nbytes(tensor));
|
||||||
|
+ fin.seek(tensor_meta.offset, SEEK_SET);
|
||||||
|
+ fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
|
||||||
|
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
|
||||||
|
+ };
|
||||||
|
+ load_tensor(metaA, loraA);
|
||||||
|
+ load_tensor(metaB, loraB);
|
||||||
|
+
|
||||||
|
+ // load base model tensor data
|
||||||
|
+ if (ml) {
|
||||||
|
+ ml->load_data_for(base_t);
|
||||||
|
+ } else {
|
||||||
|
+ ggml_backend_tensor_copy(model_t, base_t);
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ if (ggml_is_quantized(base_t->type) && !warned) {
|
||||||
|
+ LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
|
||||||
|
+ "use a f16 or f32 base model with --lora-base\n", __func__);
|
||||||
|
+ warned = true;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
|
||||||
|
+ LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
|
||||||
|
+ " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
|
||||||
|
+ ggml_free(lora_ctx);
|
||||||
|
+ ggml_backend_buffer_free(lora_buf);
|
||||||
|
+ ggml_backend_free(backend_cpu);
|
||||||
|
+ return 1;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ auto build_lora_graph = [&]() {
|
||||||
|
+ // w = w + BA*s
|
||||||
|
+ ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
||||||
|
+ ggml_set_name(BA, "BA");
|
||||||
|
+
|
||||||
|
+ if (scaling != 1.0f) {
|
||||||
|
+ BA = ggml_scale(lora_ctx, BA, scaling);
|
||||||
|
+ ggml_set_name(BA, "BA_scaled");
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ ggml_tensor * r;
|
||||||
|
+ r = ggml_add_inplace(lora_ctx, base_t, BA);
|
||||||
|
+ ggml_set_name(r, "r_add");
|
||||||
|
+
|
||||||
|
+ if (base_t->type != model_t->type) {
|
||||||
|
+ // convert the result to the model type
|
||||||
|
+ r = ggml_cast(lora_ctx, r, model_t->type);
|
||||||
|
+ ggml_set_name(r, "r_cast");
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ return r;
|
||||||
|
+ };
|
||||||
|
+
|
||||||
|
+ ggml_cgraph * gf = ggml_new_graph(lora_ctx);
|
||||||
|
+ ggml_tensor * r = build_lora_graph();
|
||||||
|
+ ggml_build_forward_expand(gf, r);
|
||||||
|
+
|
||||||
|
+ ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
|
||||||
|
+ if (graph_buf == nullptr) {
|
||||||
|
+ LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
|
||||||
|
+ ggml_free(lora_ctx);
|
||||||
|
+ ggml_backend_buffer_free(lora_buf);
|
||||||
|
+ ggml_backend_free(backend_cpu);
|
||||||
|
+ return 1;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ ggml_backend_graph_compute(backend_cpu, gf);
|
||||||
|
+
|
||||||
|
+ ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
|
||||||
|
+
|
||||||
|
+#if 0
|
||||||
|
+ // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
|
||||||
|
+ //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
|
||||||
|
+
|
||||||
|
+ // sched compute
|
||||||
|
+ ggml_build_forward_expand(gf, build_graph());
|
||||||
|
+ ggml_backend_sched_init_measure(sched, gf);
|
||||||
|
+
|
||||||
|
+ // create the graph again, since the previous one was destroyed by the measure
|
||||||
|
+ ggml_graph_clear(gf);
|
||||||
|
+ ggml_build_forward_expand(gf, build_graph());
|
||||||
|
+ ggml_backend_sched_graph_compute(sched, gf);
|
||||||
|
+ ggml_backend_sched_free(sched);
|
||||||
|
+#endif
|
||||||
|
+
|
||||||
|
+ ggml_backend_buffer_free(lora_buf);
|
||||||
|
+ ggml_backend_buffer_free(graph_buf);
|
||||||
|
+ ggml_free(lora_ctx);
|
||||||
|
+
|
||||||
|
+ n_tensors++;
|
||||||
|
+ if (n_tensors % 4 == 0) {
|
||||||
|
+ LLAMA_LOG_INFO(".");
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ ggml_backend_free(backend_cpu);
|
||||||
|
+
|
||||||
|
+ const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||||
|
+ LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||||
|
+
|
||||||
|
+ return 0;
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
|
||||||
|
+ try {
|
||||||
|
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
|
||||||
|
+ } catch (const std::exception & err) {
|
||||||
|
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
|
||||||
|
+ return 1;
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
\ No newline at end of file
|
||||||
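Two small pieces of arithmetic carry the restored GGLA path: the adapter strength is `scale * alpha / r` from the file header, and each tensor's data offset is rounded up to a 32-byte boundary via `(offset + 31) & -32`. A minimal Go sketch of just that arithmetic, with illustrative values (r = 16, alpha = 32) that are not taken from the patch:

package main

import "fmt"

// scaling mirrors `scale * (float)lora_alpha / (float)lora_r` from the patch.
func scaling(scale float64, alpha, r int32) float64 {
	return scale * float64(alpha) / float64(r)
}

// align32 mirrors `offset = (offset + 31) & -32`: round up to a 32-byte boundary.
func align32(offset int64) int64 {
	return (offset + 31) &^ 31
}

func main() {
	// Illustrative values only; a real GGLA file supplies r and alpha.
	fmt.Println(scaling(1.0, 32, 16)) // 2
	fmt.Println(align32(100))         // 128
}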
@@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 	filteredEnv := []string{}
 	for _, ev := range s.cmd.Env {
 		if strings.HasPrefix(ev, "CUDA_") ||
+			strings.HasPrefix(ev, "ROCR_") ||
 			strings.HasPrefix(ev, "ROCM_") ||
 			strings.HasPrefix(ev, "HIP_") ||
+			strings.HasPrefix(ev, "GPU_") ||
 			strings.HasPrefix(ev, "HSA_") ||
 			strings.HasPrefix(ev, "GGML_") ||
 			strings.HasPrefix(ev, "PATH=") ||

@@ -415,7 +417,17 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr

 	// reap subprocess when it exits
 	go func() {
-		s.done <- s.cmd.Wait()
+		err := s.cmd.Wait()
+		// Favor a more detailed message over the process exit status
+		if err != nil && s.status != nil && s.status.LastErrMsg != "" {
+			slog.Debug("llama runner terminated", "error", err)
+			if strings.Contains(s.status.LastErrMsg, "unknown model") {
+				s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
+			}
+			s.done <- fmt.Errorf(s.status.LastErrMsg)
+		} else {
+			s.done <- err
+		}
 	}()

 	return s, nil
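The reworked reaper prefers the runner's last captured stderr message over the bare exit status, which is what the user actually needs to see. A small self-contained sketch of that pattern (the `status` type and message text here are stand-ins, not the real llmServer fields):

package main

import (
	"errors"
	"fmt"
	"strings"
)

type status struct{ LastErrMsg string }

// doneError mirrors the reaper logic: if the process failed and stderr
// captured a more specific message, surface that instead of the exit error.
func doneError(waitErr error, st *status) error {
	if waitErr != nil && st != nil && st.LastErrMsg != "" {
		if strings.Contains(st.LastErrMsg, "unknown model") {
			return errors.New("this model is not supported by your version of Ollama. You may need to upgrade")
		}
		return errors.New(st.LastErrMsg)
	}
	return waitErr
}

func main() {
	err := doneError(errors.New("exit status 1"), &status{LastErrMsg: "unknown model architecture"})
	fmt.Println(err)
}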
@@ -578,14 +590,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 			slog.Warn("client connection closed before server finished loading, aborting load")
 			return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
 		case err := <-s.done:
-			msg := ""
-			if s.status != nil && s.status.LastErrMsg != "" {
-				msg = s.status.LastErrMsg
-			}
-			if strings.Contains(msg, "unknown model") {
-				return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
-			}
-			return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
+			return fmt.Errorf("llama runner process has terminated: %w", err)
 		default:
 		}
 		if time.Now().After(stallTimer) {
@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"math/rand"
 	"net/http"
 	"strings"

@@ -29,8 +30,9 @@ type ErrorResponse struct {
 }

 type Message struct {
 	Role    string `json:"role"`
 	Content any    `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }

 type Choice struct {

@@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
 	PresencePenalty *float64        `json:"presence_penalty_penalty"`
 	TopP            *float64        `json:"top_p"`
 	ResponseFormat  *ResponseFormat `json:"response_format"`
+	Tools           []api.Tool      `json:"tools"`
 }

 type ChatCompletion struct {

@@ -111,6 +114,7 @@ type CompletionRequest struct {
 	Stream      bool     `json:"stream"`
 	Temperature *float32 `json:"temperature"`
 	TopP        float32  `json:"top_p"`
+	Suffix      string   `json:"suffix"`
 }

 type Completion struct {

@@ -132,6 +136,15 @@ type CompletionChunk struct {
 	SystemFingerprint string `json:"system_fingerprint"`
 }

+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string `json:"name"`
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
 type Model struct {
 	Id     string `json:"id"`
 	Object string `json:"object"`
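The new ToolCall type mirrors the OpenAI wire format, where the function arguments arrive as a JSON-encoded string rather than an object. A quick sketch of decoding that shape with the struct as declared above (the sample payload is made up):

package main

import (
	"encoding/json"
	"fmt"
)

type ToolCall struct {
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

func main() {
	// Hypothetical payload in the OpenAI shape; Arguments is a string of JSON.
	raw := `{"id":"call_abc123","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"Paris, France\"}"}}`

	var tc ToolCall
	if err := json.Unmarshal([]byte(raw), &tc); err != nil {
		panic(err)
	}
	fmt.Println(tc.Function.Name, tc.Function.Arguments)
}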
@@ -170,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }

+func toolCallId() string {
+	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+	b := make([]byte, 8)
+	for i := range b {
+		b[i] = letterBytes[rand.Intn(len(letterBytes))]
+	}
+	return "call_" + strings.ToLower(string(b))
+}
+
 func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
+	for i, tc := range r.Message.ToolCalls {
+		toolCalls[i].ID = toolCallId()
+		toolCalls[i].Type = "function"
+		toolCalls[i].Function.Name = tc.Function.Name
+
+		args, err := json.Marshal(tc.Function.Arguments)
+		if err != nil {
+			slog.Error("could not marshall function arguments to json", "error", err)
+			continue
+		}
+
+		toolCalls[i].Function.Arguments = string(args)
+	}
+
 	return ChatCompletion{
 		Id:     id,
 		Object: "chat.completion",
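Going the other way, toChatCompletion marshals the structured arguments back into a string and attaches a synthetic `call_`-prefixed identifier. A compact sketch of that round trip (the argument map is illustrative):

package main

import (
	"encoding/json"
	"fmt"
	"math/rand"
	"strings"
)

// toolCallID mimics the generated identifier: "call_" plus 8 random characters.
func toolCallID() string {
	const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letters[rand.Intn(len(letters))]
	}
	return "call_" + strings.ToLower(string(b))
}

func main() {
	args := map[string]any{"location": "Paris, France", "format": "celsius"}
	encoded, _ := json.Marshal(args) // arguments become a JSON string on the wire
	fmt.Println(toolCallID(), string(encoded))
}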
@@ -179,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		SystemFingerprint: "fp_ollama",
 		Choices: []Choice{{
 			Index: 0,
-			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
+			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason

@@ -188,7 +225,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,

@@ -234,7 +270,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 		}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,

@@ -316,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		case string:
 			messages = append(messages, api.Message{Role: msg.Role, Content: content})
 		case []any:
-			message := api.Message{Role: msg.Role}
 			for _, c := range content {
 				data, ok := c.(map[string]any)
 				if !ok {

@@ -328,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 					if !ok {
 						return nil, fmt.Errorf("invalid message format")
 					}
-					message.Content = text
+					messages = append(messages, api.Message{Role: msg.Role, Content: text})
 				case "image_url":
 					var url string
 					if urlMap, ok := data["image_url"].(map[string]any); ok {

@@ -360,14 +394,26 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 					if err != nil {
 						return nil, fmt.Errorf("invalid message format")
 					}
-					message.Images = append(message.Images, img)
+					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
 				default:
 					return nil, fmt.Errorf("invalid message format")
 				}
 			}
-			messages = append(messages, message)
 		default:
-			return nil, fmt.Errorf("invalid message content type: %T", content)
+			if msg.ToolCalls == nil {
+				return nil, fmt.Errorf("invalid message content type: %T", content)
+			}
+
+			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
+			for i, tc := range msg.ToolCalls {
+				toolCalls[i].Function.Name = tc.Function.Name
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
+				if err != nil {
+					return nil, fmt.Errorf("invalid tool call arguments")
+				}
+			}
+			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
 		}
 	}

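fromChatRequest does the inverse of the response path: the OpenAI-style arguments string on an incoming assistant message is unmarshaled into a map before it is handed to the API layer, and a failure becomes the "invalid tool call arguments" error. A minimal sketch of that step (types simplified; api.ToolCall is stood in by a local struct):

package main

import (
	"encoding/json"
	"fmt"
)

type apiToolCall struct {
	Function struct {
		Name      string
		Arguments map[string]any
	}
}

func main() {
	// Arguments as they appear on an incoming assistant tool call.
	in := `{"location": "Paris, France", "format": "celsius"}`

	var tc apiToolCall
	tc.Function.Name = "get_current_weather"
	if err := json.Unmarshal([]byte(in), &tc.Function.Arguments); err != nil {
		panic(err) // the middleware returns "invalid tool call arguments" here
	}
	fmt.Println(tc.Function.Arguments["location"])
}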
@@ -425,6 +471,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		Format:  format,
 		Options: options,
 		Stream:  &r.Stream,
+		Tools:   r.Tools,
 	}, nil
 }

@@ -475,6 +522,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		Prompt:  r.Prompt,
 		Options: options,
 		Stream:  &r.Stream,
+		Suffix:  r.Suffix,
 	}, nil
 }

@@ -829,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
 		chatReq, err := fromChatRequest(req)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
 		}

 		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
@@ -20,108 +20,59 @@ const prefix = `data:image/jpeg;base64,`
 const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
 const imageURL = prefix + image

-func TestMiddlewareRequests(t *testing.T) {
+func prepareRequest(req *http.Request, body any) {
+	bodyBytes, _ := json.Marshal(body)
+	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+	req.Header.Set("Content-Type", "application/json")
+}
+
+func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		bodyBytes, _ := io.ReadAll(c.Request.Body)
+		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+		err := json.Unmarshal(bodyBytes, capturedRequest)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
+		}
+		c.Next()
+	}
+}
+
+func TestChatMiddleware(t *testing.T) {
 	type testCase struct {
 		Name     string
-		Method   string
-		Path     string
-		Handler  func() gin.HandlerFunc
 		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
 	}

-	var capturedRequest *http.Request
+	var capturedRequest *api.ChatRequest

-	captureRequestMiddleware := func() gin.HandlerFunc {
-		return func(c *gin.Context) {
-			bodyBytes, _ := io.ReadAll(c.Request.Body)
-			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-			capturedRequest = c.Request
-			c.Next()
-		}
-	}
-
 	testCases := []testCase{
 		{
 			Name: "chat handler",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model:    "test-model",
 					Messages: []Message{{Role: "user", Content: "Hello"}},
 				}
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
 				}

-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 				}

-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 				}
 			},
 		},
 		{
-			Name: "completions handler",
-			Method:  http.MethodPost,
-			Path:    "/api/generate",
-			Handler: CompletionsMiddleware,
-			Setup: func(t *testing.T, req *http.Request) {
-				temp := float32(0.8)
-				body := CompletionRequest{
-					Model:       "test-model",
-					Prompt:      "Hello",
-					Temperature: &temp,
-					Stop:        []string{"\n", "stop"},
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var genReq api.GenerateRequest
-				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
-					t.Fatal(err)
-				}
-
-				if genReq.Prompt != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
-				}
-
-				if genReq.Options["temperature"] != 1.6 {
-					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
-				}
-
-				stopTokens, ok := genReq.Options["stop"].([]any)
-
-				if !ok {
-					t.Fatalf("expected stop tokens to be a list")
-				}
-
-				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
-					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
-				}
-			},
-		},
-		{
-			Name: "chat handler with image content",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
+			Name: "chat handler with image content",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model: "test-model",

@@ -134,87 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
 						},
 					},
 				}
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
 				}

-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 				}

-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 				}

 				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])

-				if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
-					t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
+				if req.Messages[1].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
+				}
+
+				if !bytes.Equal(req.Messages[1].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
 				}
 			},
 		},
 		{
-			Name: "embed handler single input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
+			Name: "chat handler with tools",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
+					Messages: []Message{
+						{Role: "user", Content: "What's the weather like in Paris Today?"},
+						{Role: "assistant", ToolCalls: []ToolCall{{
+							ID:   "id",
+							Type: "function",
+							Function: struct {
+								Name      string `json:"name"`
+								Arguments string `json:"arguments"`
+							}{
+								Name:      "get_current_weather",
+								Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
+							},
+						}}},
+					},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != 200 {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
+
+				if req.Messages[0].Content != "What's the weather like in Paris Today?" {
+					t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
+				}
+
+				if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
+					t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
+				}
+
+				if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
+					t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
+				}
+			},
+		},
+		{
+			Name: "chat handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model:    "test-model",
+					Messages: []Message{{Role: "user", Content: 2}},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid message content type") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/chat", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
+
+			tc.Setup(t, req)
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestCompletionsMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	testCases := []testCase{
+		{
+			Name: "completions handler",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if req.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Prompt)
+				}
+
+				if req.Options["temperature"] != 1.6 {
+					t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
+				}
+
+				stopTokens, ok := req.Options["stop"].([]any)
+
+				if !ok {
+					t.Fatalf("expected stop tokens to be a list")
+				}
+
+				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
+					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
+				}
+
+				if req.Suffix != "suffix" {
+					t.Fatalf("expected 'suffix', got %s", req.Suffix)
+				}
+			},
+		},
+		{
+			Name: "completions handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: nil,
+					Stop:        []int{1, 2},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
+
+			tc.Setup(t, req)
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestEmbeddingsMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
+	}
+
+	var capturedRequest *api.EmbedRequest
+
+	testCases := []testCase{
+		{
+			Name: "embed handler single input",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: "Hello",
 					Model: "test-model",
 				}
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if req.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Input)
 				}

-				if embedReq.Input != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
-				}
-
-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
 				}
 			},
 		},
 		{
 			Name: "embed handler batch input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: []string{"Hello", "World"},
 					Model: "test-model",
 				}
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
-				}
-
-				input, ok := embedReq.Input.([]any)
-
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				input, ok := req.Input.([]any)
 				if !ok {
 					t.Fatalf("expected input to be a list")

@@ -228,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
 					t.Fatalf("expected 'World', got %s", input[1])
 				}

-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
+				}
+			},
+		},
+		{
+			Name: "embed handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Model: "test-model",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid input") {
+					t.Fatalf("error was not forwarded")
 				}
 			},
 		},
 	}

-	gin.SetMode(gin.TestMode)
-	router := gin.New()
-
 	endpoint := func(c *gin.Context) {
 		c.Status(http.StatusOK)
 	}

+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/embed", endpoint)
+
 	for _, tc := range testCases {
 		t.Run(tc.Name, func(t *testing.T) {
-			router = gin.New()
-			router.Use(captureRequestMiddleware())
-			router.Use(tc.Handler())
-			router.Handle(tc.Method, tc.Path, endpoint)
-			req, _ := http.NewRequest(tc.Method, tc.Path, nil)
+			req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)

-			if tc.Setup != nil {
-				tc.Setup(t, req)
-			}
+			tc.Setup(t, req)

 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)

-			tc.Expected(t, capturedRequest)
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
 		})
 	}
 }

@@ -275,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
 	}

 	testCases := []testCase{
-		{
-			Name:     "completions handler error forwarding",
-			Method:   http.MethodPost,
-			Path:     "/api/generate",
-			TestPath: "/api/generate",
-			Handler:  CompletionsMiddleware,
-			Endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
-			},
-			Setup: func(t *testing.T, req *http.Request) {
-				body := CompletionRequest{
-					Model:  "test-model",
-					Prompt: "Hello",
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
-
-				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
-					t.Fatalf("error was not forwarded")
-				}
-			},
-		},
 		{
 			Name:   "list handler",
 			Method: http.MethodGet,

@@ -321,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
 				})
 			},
 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				assert.Equal(t, http.StatusOK, resp.Code)
-
 				var listResp ListCompletion
 				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
 					t.Fatal(err)

@@ -386,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)

+			assert.Equal(t, http.StatusOK, resp.Code)
+
 			tc.Expected(t, resp)
 		})
 	}
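The test refactor replaces per-case router wiring with a shared captureRequestMiddleware that unmarshals the proxied body into a typed pointer, so each Expected callback receives the already-decoded request. A stripped-down sketch of the same capture pattern outside gin, using plain net/http (names and payload here are illustrative):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// capture decodes the request body into out and restores it for the next handler,
// the same trick the middleware tests use to inspect what a proxy produced.
func capture(out any, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		r.Body = io.NopCloser(bytes.NewReader(body))
		_ = json.Unmarshal(body, out)
		next.ServeHTTP(w, r)
	})
}

func main() {
	var got map[string]any
	h := capture(&got, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

	req := httptest.NewRequest(http.MethodPost, "/api/chat", bytes.NewReader([]byte(`{"model":"test-model"}`)))
	h.ServeHTTP(httptest.NewRecorder(), req)
	fmt.Println(got["model"]) // test-model
}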
@@ -8,6 +8,7 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"math/rand/v2"
 	"net/http"
 	"net/url"
 	"os"

@@ -141,6 +142,32 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *regis
 	b.err = b.run(ctx, requestURL, opts)
 }

+func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
+	var n int
+	return func(ctx context.Context) error {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+
+		n++
+
+		// n^2 backoff timer is a little smoother than the
+		// common choice of 2^n.
+		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
+		// Randomize the delay between 0.5-1.5 x msec, in order
+		// to prevent accidental "thundering herd" problems.
+		d = time.Duration(float64(d) * (rand.Float64() + 0.5))
+		t := time.NewTimer(d)
+		defer t.Stop()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-t.C:
+			return nil
+		}
+	}
+}
+
 func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
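newBackoff grows the delay quadratically (n² × 10ms, capped) and then jitters it by ±50% to avoid a thundering herd. A small sketch that just prints the capped, un-jittered schedule (the 10s cap matches the call site below; the jitter is omitted for clarity):

package main

import (
	"fmt"
	"time"
)

func main() {
	maxBackoff := 10 * time.Second
	for n := 1; n <= 40; n += 13 {
		// n^2 * 10ms, capped at maxBackoff, exactly as in newBackoff.
		d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
		fmt.Printf("attempt %2d: %v\n", n, d)
	}
}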
@@ -153,6 +180,52 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis

 	_ = file.Truncate(b.Total)

+	directURL, err := func() (*url.URL, error) {
+		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		defer cancel()
+
+		backoff := newBackoff(10 * time.Second)
+		for {
+			// shallow clone opts to be used in the closure
+			// without affecting the outer opts.
+			newOpts := new(registryOptions)
+			*newOpts = *opts
+
+			newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+				if len(via) > 10 {
+					return errors.New("maxium redirects exceeded (10) for directURL")
+				}
+
+				// if the hostname is the same, allow the redirect
+				if req.URL.Hostname() == requestURL.Hostname() {
+					return nil
+				}
+
+				// stop at the first redirect that is not
+				// the same hostname as the original
+				// request.
+				return http.ErrUseLastResponse
+			}
+
+			resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
+			if err != nil {
+				slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
+				if err := backoff(ctx); err != nil {
+					return nil, err
+				}
+				continue
+			}
+			defer resp.Body.Close()
+			if resp.StatusCode != http.StatusTemporaryRedirect {
+				return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
+			}
+			return resp.Location()
+		}
+	}()
+	if err != nil {
+		return err
+	}
+
 	g, inner := errgroup.WithContext(ctx)
 	g.SetLimit(numDownloadParts)
 	for i := range b.Parts {

@@ -165,7 +238,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 			var err error
 			for try := 0; try < maxRetries; try++ {
 				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, requestURL, w, part, opts)
+				err = b.downloadChunk(inner, directURL, w, part, opts)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					// return immediately if the context is canceled or the device is out of space
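The direct-URL probe leans on http.Client's CheckRedirect hook: same-host redirects are followed, and the first cross-host redirect is handed back un-followed via http.ErrUseLastResponse so its Location header can be read. A self-contained sketch of that policy against a local test server (all URLs are synthetic):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
)

func main() {
	// Test server that redirects to a different host.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, "https://cdn.example.com/blob", http.StatusTemporaryRedirect)
	}))
	defer srv.Close()

	origin, _ := url.Parse(srv.URL)
	client := &http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if req.URL.Hostname() == origin.Hostname() {
				return nil // same host: keep following
			}
			return http.ErrUseLastResponse // stop and hand back the redirect itself
		},
	}

	resp, err := client.Get(srv.URL)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	loc, _ := resp.Location()
	fmt.Println(resp.StatusCode, loc) // 307 https://cdn.example.com/blob
}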
@@ -54,6 +54,8 @@ type registryOptions struct {
 	Username string
 	Password string
 	Token    string
+
+	CheckRedirect func(req *http.Request, via []*http.Request) error
 }

 type Model struct {

@@ -492,6 +494,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				layers = append(layers, baseLayer.Layer)
 			}
 		case "license", "template", "system":
+			if c.Name == "template" {
+				if _, err := template.Parse(c.Args); err != nil {
+					return fmt.Errorf("%w: %s", errBadTemplate, err)
+				}
+			}
+
 			if c.Name != "license" {
 				// replace
 				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
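Validating the template at create time means a bad Modelfile fails fast with errBadTemplate instead of surfacing later at inference. A rough sketch of the same guard, using the standard text/template parser as a stand-in for the server's own template package:

package main

import (
	"errors"
	"fmt"
	"text/template"
)

var errBadTemplate = errors.New("template error")

func checkTemplate(s string) error {
	if _, err := template.New("modelfile").Parse(s); err != nil {
		return fmt.Errorf("%w: %s", errBadTemplate, err)
	}
	return nil
}

func main() {
	err := checkTemplate("{{ .Prompt") // unclosed action
	fmt.Println(errors.Is(err, errBadTemplate), err)
}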
@@ -1125,7 +1133,9 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
 		req.ContentLength = contentLength
 	}

-	resp, err := http.DefaultClient.Do(req)
+	resp, err := (&http.Client{
+		CheckRedirect: regOpts.CheckRedirect,
+	}).Do(req)
 	if err != nil {
 		return nil, err
 	}
@@ -263,13 +263,27 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
 		if t, err := template.Named(s); err != nil {
 			slog.Debug("template detection", "error", err)
 		} else {
-			tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+			layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
 			if err != nil {
 				return nil, err
 			}

-			tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
-			layers = append(layers, &layerGGML{tmpl, nil})
+			layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+			layers = append(layers, &layerGGML{layer, nil})
+
+			if t.Parameters != nil {
+				var b bytes.Buffer
+				if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
+					return nil, err
+				}
+
+				layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+				if err != nil {
+					return nil, err
+				}
+
+				layers = append(layers, &layerGGML{layer, nil})
+			}
 		}
 	}
 }

@@ -311,12 +325,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	}

 	var b bytes.Buffer
-	if err := tmpl.Execute(&b, map[string][]map[string]any{
+	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
 		"ToolCalls": {
 			{
-				"Function": map[string]any{
-					"Name":      "@@name@@",
-					"Arguments": "@@arguments@@",
+				Function: api.ToolCallFunction{
+					Name: "@@name@@",
+					Arguments: api.ToolCallFunctionArguments{
+						"@@argument@@": 1,
+					},
 				},
 			},
 		},

@@ -324,7 +340,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		return nil, false
 	}

-	var kv map[string]string
+	var kv map[string]any
 	// execute the subtree with placeholders to identify the keys
 	// trim any commands that might exist in the template
 	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {

@@ -334,17 +350,23 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	// find the keys that correspond to the name and arguments fields
 	var name, arguments string
 	for k, v := range kv {
-		switch v {
-		case "@@name@@":
+		switch v.(type) {
+		case string:
 			name = k
-		case "@@arguments@@":
+		case map[string]any:
 			arguments = k
 		}
 	}

+	if name == "" || arguments == "" {
+		return nil, false
+	}
+
 	var objs []map[string]any
 	for offset := 0; offset < len(s); {
-		if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
+		var obj map[string]any
+		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
+		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
 			break
 		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
 			// skip over any syntax errors

@@ -353,26 +375,44 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 			// skip over any unmarshalable types
 			offset += int(unmarshalType.Offset)
 		} else if err != nil {
+			slog.Error("parseToolCalls", "error", err)
 			return nil, false
 		} else {
-			// break when an object is decoded
-			break
+			offset += int(decoder.InputOffset())
+
+			// collect all nested objects
+			var collect func(any) []map[string]any
+			collect = func(obj any) (all []map[string]any) {
+				switch o := obj.(type) {
+				case map[string]any:
+					all = append(all, o)
+					for _, v := range o {
+						all = append(all, collect(v)...)
+					}
+				case []any:
+					for _, v := range o {
+						all = append(all, collect(v)...)
+					}
+				}
+
+				return all
+			}
+			objs = append(objs, collect(obj)...)
 		}
 	}

 	var toolCalls []api.ToolCall
 	for _, kv := range objs {
-		var call api.ToolCall
-		for k, v := range kv {
-			switch k {
-			case name:
-				call.Function.Name = v.(string)
-			case arguments:
-				call.Function.Arguments = v.(map[string]any)
-			}
+		n, nok := kv[name].(string)
+		a, aok := kv[arguments].(map[string]any)
+		if nok && aok {
+			toolCalls = append(toolCalls, api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name:      n,
+					Arguments: a,
+				},
+			})
 		}
-
-		toolCalls = append(toolCalls, call)
 	}

 	return toolCalls, len(toolCalls) > 0
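parseToolCalls discovers the model-specific key names by executing the tool-call subtree with placeholder values and then checking which JSON key came back as a string (the name) and which came back as an object (the arguments). A condensed sketch of that detection step; the template below is a made-up example, and the `args` type stands in for api.ToolCallFunctionArguments:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"text/template"
)

// args marshals itself as JSON when printed, standing in for api.ToolCallFunctionArguments.
type args map[string]any

func (a args) String() string {
	b, _ := json.Marshal(a)
	return string(b)
}

func main() {
	// Hypothetical tool-call subtree a model template might contain.
	tmpl := template.Must(template.New("tools").Parse(
		`{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}`))

	// Execute it with placeholder values, as parseToolCalls does.
	var b bytes.Buffer
	_ = tmpl.Execute(&b, map[string]any{
		"ToolCalls": []map[string]any{
			{"Function": map[string]any{"Name": "@@name@@", "Arguments": args{"@@argument@@": 1}}},
		},
	})

	// Whichever key holds a string is the name key; the object is the arguments key.
	var kv map[string]any
	_ = json.Unmarshal(b.Bytes(), &kv)
	for k, v := range kv {
		switch v.(type) {
		case string:
			fmt.Println("name key:", k)
		case map[string]any:
			fmt.Println("arguments key:", k)
		}
	}
}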
@@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
 	}
 }

-type function struct {
-	Name      string         `json:"name"`
-	Arguments map[string]any `json:"arguments"`
-}
-
 func readFile(t *testing.T, base, name string) *bytes.Buffer {
 	t.Helper()

@@ -167,6 +162,11 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
 		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
 		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"llama3-groq-tool-use", `<tool_call>
+{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
+{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
+</tool_call>`, true},
+		{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
 	}

 	var tools []api.Tool

@@ -181,18 +181,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,

 	calls := []api.ToolCall{
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "fahrenheit",
 					"location": "San Francisco, CA",
 				},
 			},
 		},
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "celsius",
 					"location": "Toronto, Canada",
 				},
@@ -56,6 +56,7 @@ func init() {
 }

 var errRequired = errors.New("is required")
+var errBadTemplate = errors.New("template error")

 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -275,11 +276,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}

 		r.Response = sb.String()
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-			r.ToolCalls = toolCalls
-			r.Response = ""
-		}
-
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -613,7 +609,9 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 		defer cancel()

 		quantization := cmp.Or(r.Quantize, r.Quantization)
-		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
+		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) {
+			ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
+		} else if err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
@@ -1201,11 +1199,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
 				return
 			}
 		case gin.H:
+			status, ok := r["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
 			if errorMsg, ok := r["error"].(string); ok {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
+				c.JSON(status, gin.H{"error": errorMsg})
 				return
 			} else {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
+				c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
 				return
 			}
 		default:
@@ -1295,7 +1297,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}

 	caps := []Capability{CapabilityCompletion}
-	if req.Tools != nil {
+	if len(req.Tools) > 0 {
 		caps = append(caps, CapabilityTools)
 	}

@@ -1390,9 +1392,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}

 		resp.Message.Content = sb.String()
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-			resp.Message.ToolCalls = toolCalls
-			resp.Message.Content = ""
+
+		if len(req.Tools) > 0 {
+			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+				resp.Message.ToolCalls = toolCalls
+				resp.Message.Content = ""
+			}
 		}

 		c.JSON(http.StatusOK, resp)
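Taken together, these hunks let a failed template validation surface as HTTP 400 instead of a blanket 500: the create goroutine attaches a "status" to its progress payload and waitForStream honors it. A small sketch of just that lookup, outside of gin (statusFromProgress is an illustrative name, not a function in the codebase):

package main

import (
    "fmt"
    "net/http"
)

// statusFromProgress mirrors how waitForStream picks an HTTP status out of a
// progress payload: use the "status" field when present, else 500.
func statusFromProgress(r map[string]any) int {
    status, ok := r["status"].(int)
    if !ok {
        status = http.StatusInternalServerError
    }
    return status
}

func main() {
    fmt.Println(statusFromProgress(map[string]any{"error": "template error", "status": http.StatusBadRequest})) // 400
    fmt.Println(statusFromProgress(map[string]any{"error": "boom"}))                                            // 500
}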
@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
 	if string(system) != "Say bye!" {
 		t.Errorf("expected \"Say bye!\", actual %s", system)
 	}

+	t.Run("incomplete template", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with unclosed if", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with undefined function", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
 }

 func TestCreateLicenses(t *testing.T) {
@@ -563,9 +599,10 @@ func TestCreateDetectTemplate(t *testing.T) {
 		}

 		checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
+			filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
 			filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
 			filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
-			filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
+			filepath.Join(p, "blobs", "sha256-ea34c57ba5b78b740aafe2aeb74dc6507fc3ad14170b64c26a04fb9e36c88d75"),
 		})
 	})

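Each of the new subtests expects a 400 because the Modelfile's TEMPLATE cannot be parsed by Go's text/template: an unterminated action, an if without a matching end, and a reference to an undefined function. A quick standalone check of that assumption:

package main

import (
    "fmt"
    "text/template"
)

func main() {
    for _, tmpl := range []string{
        "{{ .Prompt",       // incomplete action
        "{{ if .Prompt }}", // unclosed if
        "{{ Prompt }}",     // undefined function
    } {
        // Each Parse call returns a non-nil error, which the server maps to errBadTemplate.
        _, err := template.New("t").Parse(tmpl)
        fmt.Printf("%-20q -> %v\n", tmpl, err)
    }
}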
@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
 		getCpuFn:     gpu.GetCPUInfo,
 		reschedDelay: 250 * time.Millisecond,
 		loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
-			// add 10ms delay to simulate loading
-			time.Sleep(10 * time.Millisecond)
+			// add small delay to simulate loading
+			time.Sleep(time.Millisecond)
 			req.successCh <- &runnerRef{
 				llama: &mock,
 			}
@@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) {
 		getCpuFn:     gpu.GetCPUInfo,
 		reschedDelay: 250 * time.Millisecond,
 		loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+			// add small delay to simulate loading
+			time.Sleep(time.Millisecond)
 			req.successCh <- &runnerRef{
 				llama: &mock,
 			}
@@ -132,6 +132,8 @@ func (s *Scheduler) processPending(ctx context.Context) {
 			if len(pending.model.ProjectorPaths) > 0 && numParallel != 1 {
 				numParallel = 1
 				slog.Warn("multimodal models don't support parallel requests yet")
+			} else if strings.Contains(pending.model.Config.ModelFamily, "bert") {
+				numParallel = runtime.NumCPU()
 			}

 			for {
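The scheduler hunk is a small decision rule: multimodal models stay pinned to one parallel request, while bert-family (embedding) models fan out to one slot per CPU. A hedged sketch of that selection pulled into a helper purely for illustration (pickNumParallel does not exist in the codebase; the sample model names are made up):

package main

import (
    "fmt"
    "runtime"
    "strings"
)

// pickNumParallel mirrors the scheduler's adjustment of the requested
// parallelism: projector-backed (multimodal) models get 1, bert-family
// embedding models get one slot per CPU, everything else keeps the default.
func pickNumParallel(numParallel int, projectorPaths []string, modelFamily string) int {
    if len(projectorPaths) > 0 && numParallel != 1 {
        return 1
    }
    if strings.Contains(modelFamily, "bert") {
        return runtime.NumCPU()
    }
    return numParallel
}

func main() {
    fmt.Println(pickNumParallel(4, []string{"mmproj.bin"}, "llama")) // 1
    fmt.Println(pickNumParallel(4, nil, "nomic-bert"))               // runtime.NumCPU()
    fmt.Println(pickNumParallel(4, nil, "llama"))                    // 4
}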
@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
 	require.Len(t, s.expiredCh, 1)
 }

-type bundle struct {
+type reqBundle struct {
 	ctx     context.Context //nolint:containedctx
 	ctxDone func()
 	srv     *mockLlm
@@ -102,13 +102,13 @@ type bundle struct {
 	ggml *llm.GGML
 }

-func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 	return scenario.srv, nil
 }

-func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
-	scenario := &bundle{}
-	scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
+func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
+	b := &reqBundle{}
+	b.ctx, b.ctxDone = context.WithCancel(ctx)
 	t.Helper()

 	f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -135,124 +135,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV

 	fname := f.Name()
 	model := &Model{Name: modelName, ModelPath: fname}
-	scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
+	b.ggml, err = llm.LoadModel(model.ModelPath, 0)
 	require.NoError(t, err)

-	scenario.req = &LlmRequest{
-		ctx:             scenario.ctx,
+	if duration == nil {
+		duration = &api.Duration{Duration: 5 * time.Millisecond}
+	}
+	b.req = &LlmRequest{
+		ctx:             b.ctx,
 		model:           model,
 		opts:            api.DefaultOptions(),
-		sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
+		sessionDuration: duration,
 		successCh:       make(chan *runnerRef, 1),
 		errCh:           make(chan error, 1),
 	}
-	scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
-	return scenario
+	b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
+	return b
 }

-func TestRequests(t *testing.T) {
-	ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
+func getGpuFn() gpu.GpuInfoList {
+	g := gpu.GpuInfo{Library: "metal"}
+	g.TotalMemory = 24 * format.GigaByte
+	g.FreeMemory = 12 * format.GigaByte
+	return []gpu.GpuInfo{g}
+}
+
+func getCpuFn() gpu.GpuInfoList {
+	g := gpu.GpuInfo{Library: "cpu"}
+	g.TotalMemory = 32 * format.GigaByte
+	g.FreeMemory = 26 * format.GigaByte
+	return []gpu.GpuInfo{g}
+}
+
+func TestRequestsSameModelSameRequest(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
 	defer done()
-
-	// Same model, same request
-	scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
-	scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
-	scenario1b.req.model = scenario1a.req.model
-	scenario1b.ggml = scenario1a.ggml
-	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
-
-	// simple reload of same model
-	scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
-	tmpModel := *scenario1a.req.model
-	scenario2a.req.model = &tmpModel
-	scenario2a.ggml = scenario1a.ggml
-	scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
-
-	// Multiple loaded models
-	scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
-	scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
-	scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
-	scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
-	scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
-
 	s := InitScheduler(ctx)
-	s.getGpuFn = func() gpu.GpuInfoList {
-		g := gpu.GpuInfo{Library: "metal"}
-		g.TotalMemory = 24 * format.GigaByte
-		g.FreeMemory = 12 * format.GigaByte
-		return []gpu.GpuInfo{g}
-	}
-	s.getCpuFn = func() gpu.GpuInfoList {
-		g := gpu.GpuInfo{Library: "cpu"}
-		g.TotalMemory = 32 * format.GigaByte
-		g.FreeMemory = 26 * format.GigaByte
-		return []gpu.GpuInfo{g}
-	}
-	s.newServerFn = scenario1a.newServer
-	slog.Info("scenario1a")
-	s.pendingReqCh <- scenario1a.req
+	s.getGpuFn = getGpuFn
+	s.getCpuFn = getCpuFn
+	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
+	b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
+	b.req.model = a.req.model
+	b.ggml = a.ggml
+
+	s.newServerFn = a.newServer
+	slog.Info("a")
+	s.pendingReqCh <- a.req
 	require.Len(t, s.pendingReqCh, 1)
 	s.Run(ctx)
 	select {
-	case resp := <-scenario1a.req.successCh:
-		require.Equal(t, resp.llama, scenario1a.srv)
+	case resp := <-a.req.successCh:
+		require.Equal(t, resp.llama, a.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario1a.req.errCh)
-	case err := <-scenario1a.req.errCh:
+		require.Empty(t, a.req.errCh)
+	case err := <-a.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}

 	// Same runner as first request due to not needing a reload
-	s.newServerFn = scenario1b.newServer
-	slog.Info("scenario1b")
-	s.pendingReqCh <- scenario1b.req
+	s.newServerFn = b.newServer
+	slog.Info("b")
+	s.pendingReqCh <- b.req
 	select {
-	case resp := <-scenario1b.req.successCh:
-		require.Equal(t, resp.llama, scenario1a.srv)
+	case resp := <-b.req.successCh:
+		require.Equal(t, resp.llama, a.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario1b.req.errCh)
-	case err := <-scenario1b.req.errCh:
+		require.Empty(t, b.req.errCh)
+	case err := <-b.req.errCh:
+		t.Fatal(err.Error())
+	case <-ctx.Done():
+		t.Fatal("timeout")
+	}
+}
+
+func TestRequestsSimpleReloadSameModel(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	s.getGpuFn = getGpuFn
+	s.getCpuFn = getCpuFn
+	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
+	b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
+	tmpModel := *a.req.model
+	b.req.model = &tmpModel
+	b.ggml = a.ggml
+
+	s.newServerFn = a.newServer
+	slog.Info("a")
+	s.pendingReqCh <- a.req
+	require.Len(t, s.pendingReqCh, 1)
+	s.Run(ctx)
+	select {
+	case resp := <-a.req.successCh:
+		require.Equal(t, resp.llama, a.srv)
+		require.Empty(t, s.pendingReqCh)
+		require.Empty(t, a.req.errCh)
+	case err := <-a.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}

 	// Trigger a reload
-	s.newServerFn = scenario2a.newServer
-	scenario2a.req.model.AdapterPaths = []string{"new"}
-	slog.Info("scenario2a")
-	s.pendingReqCh <- scenario2a.req
+	s.newServerFn = b.newServer
+	b.req.model.AdapterPaths = []string{"new"}
+	slog.Info("b")
+	s.pendingReqCh <- b.req
 	// finish first two requests, so model can reload
 	time.Sleep(1 * time.Millisecond)
-	scenario1a.ctxDone()
-	scenario1b.ctxDone()
+	a.ctxDone()
 	select {
-	case resp := <-scenario2a.req.successCh:
-		require.Equal(t, resp.llama, scenario2a.srv)
+	case resp := <-b.req.successCh:
+		require.Equal(t, resp.llama, b.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario2a.req.errCh)
-	case err := <-scenario2a.req.errCh:
+		require.Empty(t, b.req.errCh)
+	case err := <-b.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
+}
+
+func TestRequestsMultipleLoadedModels(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	s.getGpuFn = getGpuFn
+	s.getCpuFn = getCpuFn
+
+	// Multiple loaded models
+	a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
+	b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
+	c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
+	c.req.opts.NumGPU = 0 // CPU load, will be allowed
+	d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded

 	envconfig.MaxRunners = 1
-	s.newServerFn = scenario3a.newServer
-	slog.Info("scenario3a")
-	s.pendingReqCh <- scenario3a.req
-	// finish prior request, so new model can load
-	time.Sleep(1 * time.Millisecond)
-	scenario2a.ctxDone()
+	s.newServerFn = a.newServer
+	slog.Info("a")
+	s.pendingReqCh <- a.req
+	s.Run(ctx)
 	select {
-	case resp := <-scenario3a.req.successCh:
-		require.Equal(t, resp.llama, scenario3a.srv)
+	case resp := <-a.req.successCh:
+		require.Equal(t, resp.llama, a.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario3a.req.errCh)
-	case err := <-scenario3a.req.errCh:
+		require.Empty(t, a.req.errCh)
+	case err := <-a.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
@@ -262,15 +292,15 @@ func TestRequests(t *testing.T) {
 	s.loadedMu.Unlock()

 	envconfig.MaxRunners = 0
-	s.newServerFn = scenario3b.newServer
-	slog.Info("scenario3b")
-	s.pendingReqCh <- scenario3b.req
+	s.newServerFn = b.newServer
+	slog.Info("b")
+	s.pendingReqCh <- b.req
 	select {
-	case resp := <-scenario3b.req.successCh:
-		require.Equal(t, resp.llama, scenario3b.srv)
+	case resp := <-b.req.successCh:
+		require.Equal(t, resp.llama, b.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario3b.req.errCh)
-	case err := <-scenario3b.req.errCh:
+		require.Empty(t, b.req.errCh)
+	case err := <-b.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
@@ -280,15 +310,15 @@ func TestRequests(t *testing.T) {
 	s.loadedMu.Unlock()

 	// This is a CPU load with NumGPU = 0 so it should load
-	s.newServerFn = scenario3c.newServer
-	slog.Info("scenario3c")
-	s.pendingReqCh <- scenario3c.req
+	s.newServerFn = c.newServer
+	slog.Info("c")
+	s.pendingReqCh <- c.req
 	select {
-	case resp := <-scenario3c.req.successCh:
-		require.Equal(t, resp.llama, scenario3c.srv)
+	case resp := <-c.req.successCh:
+		require.Equal(t, resp.llama, c.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario3c.req.errCh)
-	case err := <-scenario3c.req.errCh:
+		require.Empty(t, c.req.errCh)
+	case err := <-c.req.errCh:
 		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
@@ -298,25 +328,25 @@ func TestRequests(t *testing.T) {
 	s.loadedMu.Unlock()

 	// Try to load a model that wont fit
-	s.newServerFn = scenario3d.newServer
-	slog.Info("scenario3d")
+	s.newServerFn = d.newServer
+	slog.Info("d")
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 3)
 	s.loadedMu.Unlock()
-	scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
+	a.ctxDone() // Won't help since this one isn't big enough to make room
 	time.Sleep(2 * time.Millisecond)
-	s.pendingReqCh <- scenario3d.req
+	s.pendingReqCh <- d.req
 	// finish prior request, so new model can load
 	time.Sleep(6 * time.Millisecond)
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 2)
 	s.loadedMu.Unlock()
-	scenario3b.ctxDone()
+	b.ctxDone()
 	select {
-	case resp := <-scenario3d.req.successCh:
-		require.Equal(t, resp.llama, scenario3d.srv)
+	case resp := <-d.req.successCh:
+		require.Equal(t, resp.llama, d.srv)
 		require.Empty(t, s.pendingReqCh)
-		require.Empty(t, scenario3d.req.errCh)
+		require.Empty(t, d.req.errCh)
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -329,26 +359,19 @@ func TestGetRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()

-	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
-	scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-	scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
-	scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-	scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
+	a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond})
+	b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond})
+	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond})
 	envconfig.MaxQueuedRequests = 1
 	s := InitScheduler(ctx)
-	s.getGpuFn = func() gpu.GpuInfoList {
-		g := gpu.GpuInfo{Library: "metal"}
-		g.TotalMemory = 24 * format.GigaByte
-		g.FreeMemory = 12 * format.GigaByte
-		return []gpu.GpuInfo{g}
-	}
-	s.newServerFn = scenario1a.newServer
-	slog.Info("scenario1a")
-	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+	s.getGpuFn = getGpuFn
+	s.getCpuFn = getCpuFn
+	s.newServerFn = a.newServer
+	slog.Info("a")
+	successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
 	require.Len(t, s.pendingReqCh, 1)
-	slog.Info("scenario1b")
-	successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
+	slog.Info("b")
+	successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
 	require.Len(t, s.pendingReqCh, 1)
 	require.Empty(t, successCh1b)
 	require.Len(t, errCh1b, 1)
@@ -357,22 +380,24 @@ func TestGetRunner(t *testing.T) {
 	s.Run(ctx)
 	select {
 	case resp := <-successCh1a:
-		require.Equal(t, resp.llama, scenario1a.srv)
+		require.Equal(t, resp.llama, a.srv)
 		require.Empty(t, s.pendingReqCh)
 		require.Empty(t, errCh1a)
+	case err := <-errCh1a:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
-	scenario1a.ctxDone()
+	a.ctxDone() // Set "a" model to idle so it can unload
 	s.loadedMu.Lock()
 	require.Len(t, s.loaded, 1)
 	s.loadedMu.Unlock()

-	scenario1c.req.model.ModelPath = "bad path"
-	slog.Info("scenario1c")
-	successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
+	c.req.model.ModelPath = "bad path"
+	slog.Info("c")
+	successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
 	// Starts in pending channel, then should be quickly processsed to return an error
-	time.Sleep(5 * time.Millisecond)
+	time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload
 	require.Empty(t, successCh1c)
 	s.loadedMu.Lock()
 	require.Empty(t, s.loaded)
@@ -380,7 +405,7 @@ func TestGetRunner(t *testing.T) {
 	require.Len(t, errCh1c, 1)
 	err = <-errCh1c
 	require.Contains(t, err.Error(), "bad path")
-	scenario1b.ctxDone()
+	b.ctxDone()
 }

 // TODO - add one scenario that triggers the bogus finished event with positive ref count
@@ -389,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
 	defer done()

 	// Same model, same request
-	scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
+	scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil)
 	s := InitScheduler(ctx)
 	s.getGpuFn = func() gpu.GpuInfoList {
 		g := gpu.GpuInfo{Library: "metal"}
@@ -411,6 +436,8 @@ func TestPrematureExpired(t *testing.T) {
 		s.loadedMu.Unlock()
 		slog.Info("sending premature expired event now")
 		s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
+	case err := <-errCh1a:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -446,6 +473,8 @@ func TestUseLoadedRunner(t *testing.T) {
 	select {
 	case success := <-req.successCh:
 		require.Equal(t, r1, success)
+	case err := <-req.errCh:
+		t.Fatal(err.Error())
 	case <-ctx.Done():
 		t.Fatal("timeout")
 	}
@@ -625,8 +654,7 @@ func TestAlreadyCanceled(t *testing.T) {
 	defer done()
 	dctx, done2 := context.WithCancel(ctx)
 	done2()
-	scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
-	scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
+	scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0})
 	s := InitScheduler(ctx)
 	slog.Info("scenario1a")
 	s.pendingReqCh <- scenario1a.req
2 server/testdata/tools/command-r-plus.gotmpl vendored
@@ -46,7 +46,7 @@ Action: ```json
 {{- range .ToolCalls }}
 {
 "tool_name": "{{ .Function.Name }}",
-"parameters": {{ json .Function.Arguments }}
+"parameters": {{ .Function.Arguments }}
 }
 {{- end }}
 ]```
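The `json` helper disappears from this and the following test templates because the tool-call arguments value now prints itself as JSON, so `{{ .Function.Arguments }}` is enough. A minimal sketch of that pattern with text/template (the args type here is illustrative; the real arguments type lives in the api package):

package main

import (
    "encoding/json"
    "os"
    "text/template"
)

// args prints as JSON, so a template can write {{ .Arguments }} directly.
type args map[string]any

func (a args) String() string {
    b, _ := json.Marshal(a)
    return string(b)
}

func main() {
    t := template.Must(template.New("call").Parse(`{"name": "{{ .Name }}", "arguments": {{ .Arguments }}}`))
    _ = t.Execute(os.Stdout, map[string]any{
        "Name":      "get_current_weather",
        "Arguments": args{"location": "Paris, France"},
    })
    // Output: {"name": "get_current_weather", "arguments": {"location":"Paris, France"}}
}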
4 server/testdata/tools/firefunction.gotmpl vendored
@@ -17,7 +17,7 @@ If you decide to call functions:

 Available functions as JSON spec:
 {{- if .Tools }}
-{{ json .Tools }}
+{{ .Tools }}
 {{- end }}<|eot_id|>
 {{- end }}
 {{- range .Messages }}<|start_header_id|>
@@ -25,7 +25,7 @@ Available functions as JSON spec:
 {{- end }}<|end_header_id|>
 {{- if .Content }}{{ .Content }}
 {{- else if .ToolCalls }} functools[
-{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
+{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
 {{- end }}]
 {{- end }}<|eot_id|>
 {{- end }}<|start_header_id|>assistant<|end_header_id|>
43 server/testdata/tools/llama3-groq-tool-use.gotmpl vendored Normal file
@@ -0,0 +1,43 @@
+{{- if .Messages }}
+{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}
+{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
+<tool_call>
+{"name": <function-name>,"arguments": <args-dict>}
+</tool_call>
+
+Here are the available tools:
+<tools>
+{{- range .Tools }} {{ .Function }}
+{{- end }} </tools>
+{{- end }}
+{{- end }}<|eot_id|>
+{{- range .Messages }}
+{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
+
+{{ if eq .Role "user" }}{{ .Content }}
+{{- else if eq .Role "assistant" }}
+{{- if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}<tool_call>
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{- end }}
+</tool_call>
+{{- end }}
+{{- else if eq .Role "tool" }}<tool_response>
+{{ .Content }}
+</tool_response>
+{{- end }}<|eot_id|>
+{{- end }}
+{{- end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ else }}
+{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ end }}{{ .Response }}
+{{- if .Response }}<|eot_id|>
+{{- end }}
24 server/testdata/tools/llama3-groq-tool-use.out vendored Normal file
@@ -0,0 +1,24 @@
+<|start_header_id|>system<|end_header_id|>
+
+You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
+<tool_call>
+{"name": <function-name>,"arguments": <args-dict>}
+</tool_call>
+
+Here are the available tools:
+<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
+
+What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+<tool_call>
+{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
+</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
+
+<tool_response>
+22
+</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
4 server/testdata/tools/mistral.gotmpl vendored
@@ -1,13 +1,13 @@
 {{- range $index, $_ := .Messages }}
 {{- if eq .Role "user" }}
-{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
+{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
 {{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}

 {{ end }}{{ .Content }}[/INST]
 {{- else if eq .Role "assistant" }}
 {{- if .Content }} {{ .Content }}</s>
 {{- else if .ToolCalls }}[TOOL_CALLS] [
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
 {{- end }}]</s>
 {{- end }}
 {{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
45 server/testdata/tools/xlam.gotmpl vendored Normal file
@@ -0,0 +1,45 @@
+{{- if .System }}{{ .System }}
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- if eq .Role "user" }}### Instruction:
+{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
+[BEGIN OF TASK INSTRUCTION]
+You are an expert in composing functions. You are given a question and a set of possible functions.
+Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+If none of the functions can be used, point it out and refuse to answer.
+If the given question lacks the parameters required by the function, also point it out.
+[END OF TASK INSTRUCTION]
+
+[BEGIN OF AVAILABLE TOOLS]
+{{ $.Tools }}
+[END OF AVAILABLE TOOLS]
+
+[BEGIN OF FORMAT INSTRUCTION]
+The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
+The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
+```
+{
+"tool_calls": [
+{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
+... (more tool calls as required)
+]
+}
+```
+[END OF FORMAT INSTRUCTION]
+
+[BEGIN OF QUERY]
+{{ .Content }}
+[END OF QUERY]
+
+
+{{ else }}
+{{ .Content }}
+{{ end }}
+{{- else if .ToolCalls }}### Response:
+{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
+<|EOT|>
+{{ else if eq .Role "assistant" }}### Response:
+{{ .Content }}
+<|EOT|>
+{{ end }}
+{{- end }}### Response:
40 server/testdata/tools/xlam.out vendored Normal file
@@ -0,0 +1,40 @@
+You are a knowledgable assistant. You can answer questions and perform tasks.
+### Instruction:
+What's the weather like today in Paris?
+### Response:
+{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
+<|EOT|>
+### Response:
+The current temperature in Paris, France is 22 degrees Celsius.
+<|EOT|>
+### Instruction:
+[BEGIN OF TASK INSTRUCTION]
+You are an expert in composing functions. You are given a question and a set of possible functions.
+Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+If none of the functions can be used, point it out and refuse to answer.
+If the given question lacks the parameters required by the function, also point it out.
+[END OF TASK INSTRUCTION]
+
+[BEGIN OF AVAILABLE TOOLS]
+[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
+[END OF AVAILABLE TOOLS]
+
+[BEGIN OF FORMAT INSTRUCTION]
+The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
+The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
+```
+{
+"tool_calls": [
+{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
+... (more tool calls as required)
+]
+}
+```
+[END OF FORMAT INSTRUCTION]
+
+[BEGIN OF QUERY]
+What's the weather like today in San Francisco and Toronto?
+[END OF QUERY]
+
+
+### Response:
8 template/alfred.json Normal file
@@ -0,0 +1,8 @@
+{
+  "stop": [
+    "<start_system>",
+    "<end_message>",
+    "<start_user>",
+    "<start_assistant>"
+  ]
+}
6 template/alpaca.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "### Instruction:",
+    "### Response"
+  ]
+}
6 template/chatml.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "<|im_start|>",
+    "<|im_end|>"
+  ]
+}
8 template/chatqa.json Normal file
@@ -0,0 +1,8 @@
+{
+  "stop": [
+    "System:",
+    "User:",
+    "Assistant:",
+    "<|begin_of_text|>"
+  ]
+}
7 template/codellama-70b-instruct.json Normal file
@@ -0,0 +1,7 @@
+{
+  "stop": [
+    "Source:",
+    "Destination:",
+    "<step>"
+  ]
+}
6 template/falcon-instruct.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "User:",
+    "Assistant:"
+  ]
+}
6 template/gemma-instruct.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "<start_of_turn>",
+    "<end_of_turn>"
+  ]
+}
7 template/granite-instruct.json Normal file
@@ -0,0 +1,7 @@
+{
+  "stop": [
+    "System:",
+    "Question:",
+    "Answer:"
+  ]
+}
8 template/llama2-chat.json Normal file
@@ -0,0 +1,8 @@
+{
+  "stop": [
+    "[INST]",
+    "[/INST]",
+    "<<SYS>>",
+    "<</SYS>>"
+  ]
+}
7 template/llama3-instruct.json Normal file
@@ -0,0 +1,7 @@
+{
+  "stop": [
+    "<|start_header_id|>",
+    "<|end_header_id|>",
+    "<|eot_id|>"
+  ]
+}
6 template/magicoder.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "@@ Instruction",
+    "@@ Response"
+  ]
+}
6 template/mistral-instruct.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "<|im_start|>",
+    "<|im_end|>"
+  ]
+}
5 template/openchat.json Normal file
@@ -0,0 +1,5 @@
+{
+  "stop": [
+    "<|end_of_turn|>"
+  ]
+}
8 template/phi-3.json Normal file
@@ -0,0 +1,8 @@
+{
+  "stop": [
+    "<|end|>",
+    "<|system|>",
+    "<|user|>",
+    "<|assistant|>"
+  ]
+}
7 template/solar-instruct.json Normal file
@@ -0,0 +1,7 @@
+{
+  "stop": [
+    "### System:",
+    "### User:",
+    "### Assistant"
+  ]
+}
7 template/starcoder2-instruct.json Normal file
@@ -0,0 +1,7 @@
+{
+  "stop": [
+    "### Instruction",
+    "### Response",
+    "<|endoftext|>"
+  ]
+}
@@ -23,6 +23,7 @@ import (
 var indexBytes []byte

 //go:embed *.gotmpl
+//go:embed *.json
 var templatesFS embed.FS

 var templatesOnce = sync.OnceValues(func() ([]*named, error) {
@@ -39,6 +40,15 @@ var templatesOnce = sync.OnceValues(func() ([]*named, error) {

 		// normalize line endings
 		t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
+
+		params, err := templatesFS.ReadFile(t.Name + ".json")
+		if err != nil {
+			continue
+		}
+
+		if err := json.Unmarshal(params, &t.Parameters); err != nil {
+			return nil, err
+		}
 	}

 	return templates, nil
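With this hunk and the struct change below, every embedded template may ship a sidecar <name>.json carrying default parameters (currently just stop sequences), as the template/*.json files added above do. A small sketch of the same lookup in isolation, assuming a package that embeds chatml.gotmpl and chatml.json like the ones in this change (loadParams is an illustrative name):

package main

import (
    "embed"
    "encoding/json"
    "fmt"
)

//go:embed *.gotmpl *.json
var templatesFS embed.FS

type params struct {
    Stop []string `json:"stop"`
}

// loadParams reads the optional sidecar JSON for a template; a missing file
// simply means the template defines no extra parameters.
func loadParams(name string) (*params, error) {
    b, err := templatesFS.ReadFile(name + ".json")
    if err != nil {
        return nil, nil
    }
    var p params
    if err := json.Unmarshal(b, &p); err != nil {
        return nil, err
    }
    return &p, nil
}

func main() {
    p, err := loadParams("chatml")
    if err != nil {
        panic(err)
    }
    if p != nil {
        fmt.Println(p.Stop) // e.g. [<|im_start|> <|im_end|>]
    }
}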
@@ -48,6 +58,10 @@ type named struct {
 	Name     string `json:"name"`
 	Template string `json:"template"`
 	Bytes    []byte
+
+	Parameters *struct {
+		Stop []string `json:"stop"`
+	}
 }

 func (t named) Reader() io.Reader {
@@ -150,9 +164,9 @@ func (t *Template) Vars() []string {

 type Values struct {
 	Messages []api.Message
-	Tools    []api.Tool
+	api.Tools
 	Prompt string
 	Suffix string

 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
@@ -217,6 +231,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"System":   system,
 			"Messages": messages,
 			"Tools":    v.Tools,
+			"Response": "",
 		})
 	}

@@ -263,6 +278,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
 		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
 			cut = true
+			return false
 		}

 		return cut
@@ -270,8 +286,9 @@ func (t *Template) Execute(w io.Writer, v Values) error {

 	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
 	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
 		"System":   system,
 		"Prompt":   prompt,
+		"Response": response,
 	}); err != nil {
 		return err
 	}
@@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {

 Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
 		},
+		{
+			"mistral assistant",
+			[]template{
+				{"no response", `[INST] {{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `
+{{- range $i, $m := .Messages }}
+{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
+{{- end }}`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+					{Role: "assistant", Content: "My name is Ollama and I"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
+		},
 		{
 			"chatml",
 			[]template{
6 template/vicuna.json Normal file
@@ -0,0 +1,6 @@
+{
+  "stop": [
+    "USER:",
+    "ASSISTANT:"
+  ]
+}
8 template/zephyr.json Normal file
@@ -0,0 +1,8 @@
+{
+  "stop": [
+    "<|system|>",
+    "</s>",
+    "<|user|>",
+    "<|assistant|>"
+  ]
+}