Compare commits
147 Commits
mxyng/crea ... royh-opena

Commits (SHA1):
568416ba17, 80cba42ab2, 6477a7aca4, 51214ddef5, b950d749a9, 3702ed7532, 6266603b17,
499e87c9ba, cd0853f2d5, d290e87513, 97c20ede33, 5a83f79afd, 987dbab0b0, a8388beb94,
5afbb60fc4, 4cb5d7decc, 8eac50dd4f, 4a565cbf94, 64039df6d7, 7ac6d462ec, ef5136a745,
8288ec8824, d02bbebb11, 224337b32f, 9e35d9bbee, b9f5e16c80, e9f7f36029, 057d31861e,
f7ee012300, 1ed0aa8fea, ef98803d63, 02fea420e5, 22c5451fc2, 23ebbaa46e, 9ac0a7a50b,
e5c65a85df, 33627331a3, 36c87c433b, 179737feb7, 47353f5ee4, 10e768826c, 5056bb9c01,
c4cf8ad559, 57ec6901eb, e64f9ebb44, 791650ddef, efbf41ed81, cf15589851, 19753c18c0,
41be28096a, 37a570f962, 5a739ff4cb, 4e262eb2a8, 4cfcbc328f, 79292ff3e0, 8ea500441d,
b50c818623, b99e750b62, 1f50356e8e, 22c81f62ec, 2d1e3c3229, 4918fae535, 0aff67877e,
f6f759fc5f, 9544a57ee4, b51e3b63ac, 6bbbc50f10, 9bbddc37a7, e4ff73297d, b44320db13,
2644c4e682, 04cde43b2a, 0bacb30007, 53da2c6965, d8def1ff94, 571dc61955, 0e09c380fc,
0ee87615c7, f8241bfba3, 4607c70641, c12f1c5b99, a08f20d910, 6cea036027, 5796bfc401,
f1a379aa56, 9ae146993e, e0348d3fe8, 2cc854f8cb, 5304b765b2, fb6cbc02fb, 4fd5f3526a,
842f85f758, 9d30f9f8b3, 631cfd9e62, 326363b3a7, ac7a842e55, 2c3fe1fd97, 269ed6e6a2,
78fb33dd07, 8f8e736b13, d89454de80, af28b94533, e9188e971a, 78eddfc068, 02c24d3d01,
52abc8acb7, 4d71c559b2, 0d16eb310e, 8072e205ff, 955f2a4e03, 105e36765d, 3c75113e37,
ccd7785859, 3b5a4a77f3, daed0634a9, 0d4dd707bc, 0e982bc1f4, 6298f49816, ef757da2c9,
e5352297d9, 65a5040e09, d626b99b54, fa7be5aab4, dddb58a38b, 400056e154, d2f19024d0,
69c04eecc4, 996bb1b85e, 422dcc3856, 020bd60ab2, 8e277b72bb, 4f67b39d26, 2425281317,
0403e9860e, 33a65e3ba3, 88bcd79bb9, 7e571f95f0, da8e2a0447, a30915bde1, 58e3fff311,
3f0b309ad4, 26e4e66faf, 97c9e11768, 9bd00041fa, 4e986a823c, 02169f3e60, 784bf88b0d
.github/workflows/release.yaml (vendored): 2 changes

@@ -147,7 +147,7 @@ jobs:
       run: |
         $ErrorActionPreference = "Stop"
         write-host "downloading AMD HIP Installer"
-        Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+        Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
         write-host "Installing AMD HIP"
         Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
         write-host "Completed AMD HIP"
.github/workflows/test.yaml (vendored): 6 changes

@@ -58,6 +58,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     env:
       GOARCH: ${{ matrix.arch }}
+      CGO_ENABLED: '1'
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
@@ -79,6 +80,7 @@ jobs:
       - run: go generate -x ./...
         if: ${{ ! startsWith(matrix.os, 'windows-') }}
        name: 'Unix Go Generate'
+      - run: go build .
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
@@ -124,7 +126,7 @@ jobs:
     strategy:
       matrix:
         rocm-version:
-          - '6.1.1'
+          - '6.1.2'
     runs-on: linux
     container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
     steps:
@@ -167,7 +169,7 @@ jobs:
       run: |
         $ErrorActionPreference = "Stop"
        write-host "downloading AMD HIP Installer"
-        Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+        Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
         write-host "Installing AMD HIP"
         Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
         write-host "Completed AMD HIP"
@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
 ARG CMAKE_VERSION=3.22.1
 # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
 ARG CUDA_VERSION=11.3.1
-ARG ROCM_VERSION=6.1.1
+ARG ROCM_VERSION=6.1.2

 # Copy the minimal context we need to run the generate scripts
 FROM scratch AS llm-code
@@ -70,12 +70,12 @@ RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
 RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh

-FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
+FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64
 ARG CMAKE_VERSION
 ARG GOLANG_VERSION
 COPY ./scripts/rh_linux_deps.sh /
 RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
-ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
+ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
 COPY --from=llm-code / /go/src/github.com/ollama/ollama/
 ARG OLLAMA_CUSTOM_CPU_DEFS
 ARG CGO_CFLAGS
@@ -293,6 +293,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
+- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
+- [AI Studio](https://github.com/MindWorkAI/AI-Studio)

 ### Terminal

@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	return nil
 }

-// Embeddings generates embeddings from a model.
+// Embed generates embeddings from a model.
+func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
+	var resp EmbedResponse
+	if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// Embeddings generates an embedding from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
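The hunk above adds a Client.Embed method alongside the existing Client.Embeddings. A minimal usage sketch follows; it assumes the github.com/ollama/ollama/api import path and the api.ClientFromEnvironment constructor, neither of which appears in this hunk, and the model name is only a placeholder.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment() // assumed constructor; any *api.Client works
	if err != nil {
		log.Fatal(err)
	}

	// EmbedRequest.Input is typed as any, so a single string or a slice of strings is accepted.
	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm", // placeholder model name
		Input: []string{"why is the sky blue?", "why is grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}

	// EmbedResponse.Embeddings is [][]float32: one vector per input.
	fmt.Println(len(resp.Embeddings), len(resp.Embeddings[0]))
}
```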
api/types.go: 182 changes

@@ -47,6 +47,9 @@ type GenerateRequest struct {
 	// Prompt is the textual prompt to send to the model.
 	Prompt string `json:"prompt"`

+	// Suffix is the text that comes after the inserted text.
+	Suffix string `json:"suffix"`
+
 	// System overrides the model's default system message/prompt.
 	System string `json:"system"`

@@ -97,6 +100,9 @@ type ChatRequest struct {
 	// followin the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`

+	// Tools is an optional list of tools the model has access to.
+	Tools []Tool `json:"tools,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -105,9 +111,46 @@ type ChatRequest struct {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
 	Role    string      `json:"role"`
-	Content string      `json:"content"`
+	Content string      `json:"content,omitempty"`
 	Images  []ImageData `json:"images,omitempty"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+}
+
+type ToolCall struct {
+	Function struct {
+		Name      string         `json:"name"`
+		Arguments map[string]any `json:"arguments"`
+	} `json:"function"`
+}
+
+type Tool struct {
+	Type     string `json:"type"`
+	Function struct {
+		Name        string `json:"name"`
+		Description string `json:"description"`
+		Parameters  struct {
+			Type       string   `json:"type"`
+			Required   []string `json:"required"`
+			Properties map[string]struct {
+				Type        string   `json:"type"`
+				Description string   `json:"description"`
+				Enum        []string `json:"enum,omitempty"`
+			} `json:"properties"`
+		} `json:"parameters"`
+	} `json:"function"`
+}
+
+func (m *Message) UnmarshalJSON(b []byte) error {
+	type Alias Message
+	var a Alias
+	if err := json.Unmarshal(b, &a); err != nil {
+		return err
+	}
+
+	*m = Message(a)
+	m.Role = strings.ToLower(m.Role)
+	return nil
 }

 // ChatResponse is the response returned by [Client.Chat]. Its fields are
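As a sketch of how the new Message fields behave on the wire (the tool name and arguments below are invented, and the github.com/ollama/ollama/api import path is assumed from the go.mod section later in this compare), decoding an assistant message that carries a tool call exercises both ToolCalls and the role-lowercasing UnmarshalJSON added above:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Content is omitted here, which is now legal because of the omitempty tag.
	raw := `{
		"role": "Assistant",
		"tool_calls": [
			{"function": {"name": "get_current_weather", "arguments": {"location": "Paris", "format": "celsius"}}}
		]
	}`

	var msg api.Message
	if err := json.Unmarshal([]byte(raw), &msg); err != nil {
		log.Fatal(err)
	}

	fmt.Println(msg.Role) // "assistant": UnmarshalJSON lower-cases the role
	for _, tc := range msg.ToolCalls {
		fmt.Println(tc.Function.Name, tc.Function.Arguments["location"])
	}
}
```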
@@ -159,49 +202,42 @@ type Options struct {

 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
 	UseNUMA   bool `json:"numa,omitempty"`
 	NumCtx    int  `json:"num_ctx,omitempty"`
 	NumBatch  int  `json:"num_batch,omitempty"`
 	NumGPU    int  `json:"num_gpu,omitempty"`
 	MainGPU   int  `json:"main_gpu,omitempty"`
 	LowVRAM   bool `json:"low_vram,omitempty"`
 	F16KV     bool `json:"f16_kv,omitempty"`
 	LogitsAll bool `json:"logits_all,omitempty"`
 	VocabOnly bool `json:"vocab_only,omitempty"`
-	UseMMap   TriState `json:"use_mmap,omitempty"`
+	UseMMap   *bool `json:"use_mmap,omitempty"`
 	UseMLock  bool `json:"use_mlock,omitempty"`
 	NumThread int  `json:"num_thread,omitempty"`
 }

-type TriState int
-
-const (
-	TriStateUndefined TriState = -1
-	TriStateFalse     TriState = 0
-	TriStateTrue      TriState = 1
-)
-
-func (b *TriState) UnmarshalJSON(data []byte) error {
-	var v bool
-	if err := json.Unmarshal(data, &v); err != nil {
-		return err
-	}
-	if v {
-		*b = TriStateTrue
-	}
-	*b = TriStateFalse
-	return nil
+// EmbedRequest is the request passed to [Client.Embed].
+type EmbedRequest struct {
+	// Model is the model name.
+	Model string `json:"model"`
+
+	// Input is the input to embed.
+	Input any `json:"input"`
+
+	// KeepAlive controls how long the model will stay loaded in memory following
+	// this request.
+	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Options lists model-specific options.
+	Options map[string]interface{} `json:"options"`
 }

-func (b *TriState) MarshalJSON() ([]byte, error) {
-	if *b == TriStateUndefined {
-		return nil, nil
-	}
-	var v bool
-	if *b == TriStateTrue {
-		v = true
-	}
-	return json.Marshal(v)
+// EmbedResponse is the response from [Client.Embed].
+type EmbedResponse struct {
+	Model      string      `json:"model"`
+	Embeddings [][]float32 `json:"embeddings"`
 }

 // EmbeddingRequest is the request passed to [Client.Embeddings].
@@ -250,8 +286,10 @@ type DeleteRequest struct {

 // ShowRequest is the request passed to [Client.Show].
 type ShowRequest struct {
 	Model  string `json:"model"`
 	System string `json:"system"`
+
+	// Template is deprecated
 	Template string `json:"template"`
 	Verbose  bool   `json:"verbose"`

@@ -345,6 +383,13 @@ type ProcessModelResponse struct {
 	SizeVRAM int64 `json:"size_vram"`
 }

+type RetrieveModelResponse struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
 type TokenResponse struct {
 	Token string `json:"token"`
 }
@@ -360,6 +405,9 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`

+	// ToolCalls is the list of tools the model wants to call
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`

@@ -437,19 +485,6 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 			continue
 		}

-		if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
-			val, ok := val.(bool)
-			if !ok {
-				return fmt.Errorf("option %q must be of type boolean", key)
-			}
-			if val {
-				field.SetInt(int64(TriStateTrue))
-			} else {
-				field.SetInt(int64(TriStateFalse))
-			}
-			continue
-		}
-
 		switch field.Kind() {
 		case reflect.Int:
 			switch t := val.(type) {
@@ -496,6 +531,17 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 				slice[i] = str
 			}
 			field.Set(reflect.ValueOf(slice))
+		case reflect.Pointer:
+			var b bool
+			if field.Type() == reflect.TypeOf(&b) {
+				val, ok := val.(bool)
+				if !ok {
+					return fmt.Errorf("option %q must be of type boolean", key)
+				}
+				field.Set(reflect.ValueOf(&val))
+			} else {
+				return fmt.Errorf("unknown type loading config params: %v %v", field.Kind(), field.Type())
+			}
 		default:
 			return fmt.Errorf("unknown type loading config params: %v", field.Kind())
 		}
@@ -538,7 +584,7 @@ func DefaultOptions() Options {
 			LowVRAM:  false,
 			F16KV:    true,
 			UseMLock: false,
-			UseMMap:  TriStateUndefined,
+			UseMMap:  nil,
 			UseNUMA:  false,
 		},
 	}
@@ -608,19 +654,6 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 		} else {
 			field := valueOpts.FieldByName(opt.Name)
 			if field.IsValid() && field.CanSet() {
-				if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
-					boolVal, err := strconv.ParseBool(vals[0])
-					if err != nil {
-						return nil, fmt.Errorf("invalid bool value %s", vals)
-					}
-					if boolVal {
-						out[key] = TriStateTrue
-					} else {
-						out[key] = TriStateFalse
-					}
-					continue
-				}
-
 				switch field.Kind() {
 				case reflect.Float32:
 					floatVal, err := strconv.ParseFloat(vals[0], 32)
@@ -648,6 +681,17 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 				case reflect.Slice:
 					// TODO: only string slices are supported right now
 					out[key] = vals
+				case reflect.Pointer:
+					var b bool
+					if field.Type() == reflect.TypeOf(&b) {
+						boolVal, err := strconv.ParseBool(vals[0])
+						if err != nil {
+							return nil, fmt.Errorf("invalid bool value %s", vals)
+						}
+						out[key] = &boolVal
+					} else {
+						return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
+					}
 				default:
 					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
 				}
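A short sketch of what the UseMMap change above means for callers (field and package names taken from the hunks above; the api import path is assumed): leaving use_mmap out of the options now yields a nil pointer, which is distinguishable from an explicit false.

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	var unset, disabled api.Runner

	if err := json.Unmarshal([]byte(`{}`), &unset); err != nil {
		log.Fatal(err)
	}
	if err := json.Unmarshal([]byte(`{"use_mmap": false}`), &disabled); err != nil {
		log.Fatal(err)
	}

	fmt.Println(unset.UseMMap == nil)                          // true: the caller said nothing
	fmt.Println(disabled.UseMMap != nil && !*disabled.UseMMap) // true: the caller disabled mmap
}
```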
@@ -108,25 +108,27 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
 }

 func TestUseMmapParsingFromJSON(t *testing.T) {
+	tr := true
+	fa := false
 	tests := []struct {
 		name string
 		req  string
-		exp  TriState
+		exp  *bool
 	}{
 		{
 			name: "Undefined",
 			req:  `{ }`,
-			exp:  TriStateUndefined,
+			exp:  nil,
 		},
 		{
 			name: "True",
 			req:  `{ "use_mmap": true }`,
-			exp:  TriStateTrue,
+			exp:  &tr,
 		},
 		{
 			name: "False",
 			req:  `{ "use_mmap": false }`,
-			exp:  TriStateFalse,
+			exp:  &fa,
 		},
 	}

@@ -144,50 +146,52 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 }

 func TestUseMmapFormatParams(t *testing.T) {
+	tr := true
+	fa := false
 	tests := []struct {
 		name string
 		req  map[string][]string
-		exp  TriState
+		exp  *bool
 		err  error
 	}{
 		{
 			name: "True",
 			req: map[string][]string{
-				"use_mmap": []string{"true"},
+				"use_mmap": {"true"},
 			},
-			exp: TriStateTrue,
+			exp: &tr,
 			err: nil,
 		},
 		{
 			name: "False",
 			req: map[string][]string{
-				"use_mmap": []string{"false"},
+				"use_mmap": {"false"},
 			},
-			exp: TriStateFalse,
+			exp: &fa,
 			err: nil,
 		},
 		{
 			name: "Numeric True",
 			req: map[string][]string{
-				"use_mmap": []string{"1"},
+				"use_mmap": {"1"},
 			},
-			exp: TriStateTrue,
+			exp: &tr,
 			err: nil,
 		},
 		{
 			name: "Numeric False",
 			req: map[string][]string{
-				"use_mmap": []string{"0"},
+				"use_mmap": {"0"},
 			},
-			exp: TriStateFalse,
+			exp: &fa,
 			err: nil,
 		},
 		{
 			name: "invalid string",
 			req: map[string][]string{
-				"use_mmap": []string{"foo"},
+				"use_mmap": {"foo"},
 			},
-			exp: TriStateUndefined,
+			exp: nil,
 			err: fmt.Errorf("invalid bool value [foo]"),
 		},
 	}
@@ -195,12 +199,35 @@ func TestUseMmapFormatParams(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			resp, err := FormatParams(test.req)
-			require.Equal(t, err, test.err)
+			require.Equal(t, test.err, err)
 			respVal, ok := resp["use_mmap"]
-			if test.exp != TriStateUndefined {
+			if test.exp != nil {
 				assert.True(t, ok, "resp: %v", resp)
-				assert.Equal(t, test.exp, respVal)
+				assert.Equal(t, *test.exp, *respVal.(*bool))
 			}
 		})
 	}
 }
+
+func TestMessage_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{`{"role": "USER", "content": "Hello!"}`, "user"},
+		{`{"role": "System", "content": "Initialization complete."}`, "system"},
+		{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
+		{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
+	}
+
+	for _, test := range tests {
+		var msg Message
+		if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
+			t.Errorf("Unexpected error: %v", err)
+		}
+
+		if msg.Role != test.expected {
+			t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
+		}
+	}
+}
@@ -127,6 +127,10 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
 Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
 ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved

+[InstallDelete]
+Type: filesandordirs; Name: "{%TEMP}\ollama*"
+Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
+
 [Messages]
 WizardReady=Ollama Windows Preview
 ReadyLabel1=%nLet's get you up and running with your own large language models.
@@ -843,7 +843,6 @@ type runOptions struct {
 	WordWrap   bool
 	Format     string
 	System     string
-	Template   string
 	Images     []api.ImageData
 	Options    map[string]interface{}
 	MultiModal bool
@@ -1037,7 +1036,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		Images:    opts.Images,
 		Format:    opts.Format,
 		System:    opts.System,
-		Template:  opts.Template,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
 	}
@@ -27,7 +27,6 @@ const (
 	MultilineNone MultilineState = iota
 	MultilinePrompt
 	MultilineSystem
-	MultilineTemplate
 )

 func loadModel(cmd *cobra.Command, opts *runOptions) error {
@@ -94,7 +93,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
 		fmt.Fprintln(os.Stderr, "  /set parameter ...     Set a parameter")
 		fmt.Fprintln(os.Stderr, "  /set system <string>   Set system message")
-		fmt.Fprintln(os.Stderr, "  /set template <string> Set prompt template")
 		fmt.Fprintln(os.Stderr, "  /set history           Enable history")
 		fmt.Fprintln(os.Stderr, "  /set nohistory          Disable history")
 		fmt.Fprintln(os.Stderr, "  /set wordwrap           Enable wordwrap")
@@ -204,10 +202,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
 				fmt.Println("Set system message.")
 				sb.Reset()
-			case MultilineTemplate:
-				opts.Template = sb.String()
-				fmt.Println("Set prompt template.")
-				sb.Reset()
 			}

 			multiline = MultilineNone
@@ -326,17 +320,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				}
 				fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
 				opts.Options[args[2]] = fp[args[2]]
-			case "system", "template":
+			case "system":
 				if len(args) < 3 {
 					usageSet()
 					continue
 				}

-				if args[1] == "system" {
-					multiline = MultilineSystem
-				} else if args[1] == "template" {
-					multiline = MultilineTemplate
-				}
+				multiline = MultilineSystem

 				line := strings.Join(args[2:], " ")
 				line, ok := strings.CutPrefix(line, `"""`)
@@ -356,23 +346,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					continue
 				}

-				if args[1] == "system" {
-					opts.System = sb.String() // for display in modelfile
-					newMessage := api.Message{Role: "system", Content: sb.String()}
-					// Check if the slice is not empty and the last message is from 'system'
-					if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
-						// Replace the last message
-						opts.Messages[len(opts.Messages)-1] = newMessage
-					} else {
-						opts.Messages = append(opts.Messages, newMessage)
-					}
-					fmt.Println("Set system message.")
-					sb.Reset()
-				} else if args[1] == "template" {
-					opts.Template = sb.String()
-					fmt.Println("Set prompt template.")
-					sb.Reset()
-				}
+				opts.System = sb.String() // for display in modelfile
+				newMessage := api.Message{Role: "system", Content: sb.String()}
+				// Check if the slice is not empty and the last message is from 'system'
+				if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
+					// Replace the last message
+					opts.Messages[len(opts.Messages)-1] = newMessage
+				} else {
+					opts.Messages = append(opts.Messages, newMessage)
+				}
+				fmt.Println("Set system message.")
+				sb.Reset()

 				sb.Reset()
 				continue
@@ -393,7 +377,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			req := &api.ShowRequest{
 				Name:    opts.Model,
 				System:  opts.System,
-				Template: opts.Template,
 				Options: opts.Options,
 			}
 			resp, err := client.Show(cmd.Context(), req)
@@ -437,12 +420,9 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					fmt.Println("No system message was specified for this model.")
 				}
 			case "template":
-				switch {
-				case opts.Template != "":
-					fmt.Println(opts.Template + "\n")
-				case resp.Template != "":
+				if resp.Template != "" {
 					fmt.Println(resp.Template)
-				default:
+				} else {
 					fmt.Println("No prompt template was specified for this model.")
 				}
 			default:
@@ -536,10 +516,6 @@ func buildModelfile(opts runOptions) string {
 		fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
 	}

-	if opts.Template != "" {
-		fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
-	}
-
 	keys := make([]string, 0)
 	for k := range opts.Options {
 		keys = append(keys, k)
@@ -59,7 +59,6 @@ func TestModelfileBuilder(t *testing.T) {
 	opts := runOptions{
 		Model:  "hork",
 		System: "You are part horse and part shark, but all hork. Do horklike things",
-		Template: "This is a template.",
 		Messages: []api.Message{
 			{Role: "user", Content: "Hey there hork!"},
 			{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
@@ -75,7 +74,6 @@ func TestModelfileBuilder(t *testing.T) {
 	mf := buildModelfile(opts)
 	expectedModelfile := `FROM {{.Model}}
 SYSTEM """{{.System}}"""
-TEMPLATE """{{.Template}}"""
 PARAMETER penalize_newline false
 PARAMETER seed 42
 PARAMETER stop [hi there]
@@ -97,7 +95,6 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
 	mf = buildModelfile(opts)
 	expectedModelfile = `FROM {{.ParentModel}}
 SYSTEM """{{.System}}"""
-TEMPLATE """{{.Template}}"""
 PARAMETER penalize_newline false
 PARAMETER seed 42
 PARAMETER stop [hi there]
|
|||||||
you might use:
|
you might use:
|
||||||
|
|
||||||
```
|
```
|
||||||
OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
|
OLLAMA_CUSTOM_CPU_DEFS="-DGGML_AVX=on -DGGML_AVX2=on -DGGML_F16C=on -DGGML_FMA=on" go generate ./...
|
||||||
go build .
|
go build .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@@ -266,8 +266,10 @@ If there is insufficient available memory to load a new model request while one
|
|||||||
|
|
||||||
Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
|
Parallel request processing for a given model results in increasing the context size by the number of parallel requests. For example, a 2K context with 4 parallel requests will result in an 8K context and additional memory allocation.
|
||||||
|
|
||||||
The following server settings may be used to adjust how Ollama handles concurrent requests:
|
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||||
|
|
||||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
|
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
|
||||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
|
||||||
|
|
||||||
|
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
|
@@ -27,6 +27,11 @@ chat_completion = client.chat.completions.create(
     ],
     model='llama3',
 )
+
+completion = client.completions.create(
+    model="llama3",
+    prompt="Say this is a test"
+)
 ```

 ### OpenAI JavaScript library
@@ -45,6 +50,11 @@ const chatCompletion = await openai.chat.completions.create({
   messages: [{ role: 'user', content: 'Say this is a test' }],
   model: 'llama3',
 })
+
+const completion = await openai.completions.create({
+  model: "llama3",
+  prompt: "Say this is a test.",
+})
 ```

 ### `curl`
@@ -65,6 +75,13 @@ curl http://localhost:11434/v1/chat/completions \
             }
         ]
     }'
+
+curl http://localhost:11434/v1/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "llama3",
+        "prompt": "Say this is a test"
+    }'
 ```

 ## Endpoints
@@ -102,8 +119,71 @@ curl http://localhost:11434/v1/chat/completions \
 - [ ] `user`
 - [ ] `n`

+### `/v1/completions`
+
+#### Supported features
+
+- [x] Completions
+- [x] Streaming
+- [x] JSON mode
+- [x] Reproducible outputs
+- [ ] Logprobs
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `prompt`
+- [x] `frequency_penalty`
+- [x] `presence_penalty`
+- [x] `seed`
+- [x] `stop`
+- [x] `stream`
+- [x] `temperature`
+- [x] `top_p`
+- [x] `max_tokens`
+- [x] `suffix`
+- [ ] `best_of`
+- [ ] `echo`
+- [ ] `logit_bias`
+- [ ] `user`
+- [ ] `n`
+
 #### Notes

+- `prompt` currently only accepts a string
+
+### `/v1/completions`
+
+#### Supported features
+
+- [x] Completions
+- [x] Streaming
+- [x] JSON mode
+- [x] Reproducible outputs
+- [ ] Logprobs
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `prompt`
+- [x] `frequency_penalty`
+- [x] `presence_penalty`
+- [x] `seed`
+- [x] `stop`
+- [x] `stream`
+- [x] `temperature`
+- [x] `top_p`
+- [x] `max_tokens`
+- [ ] `best_of`
+- [ ] `echo`
+- [ ] `suffix`
+- [ ] `logit_bias`
+- [ ] `user`
+- [ ] `n`
+
+#### Notes
+
+- `prompt` currently only accepts a string
 - `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached

 ## Models
@@ -70,14 +70,18 @@ curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh

 If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/

-## Container fails to run on NVIDIA GPU
+## NVIDIA GPU Discovery

-Make sure you've set up the container runtime first as described in [docker.md](./docker.md)
+When Ollama starts up, it takes inventory of the GPUs present in the system to determine compatibility and how much VRAM is available. Sometimes this discovery can fail to find your GPUs. In general, running the latest driver will yield the best results.

-Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
+### Linux NVIDIA Troubleshooting

-- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
-- Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
+If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
+
+Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
+
+- If you are using a container, is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
+- Is the uvm driver loaded? `sudo nvidia-modprobe -u`
 - Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
 - Try rebooting
 - Make sure you're running the latest nvidia drivers
@@ -85,3 +89,8 @@ Sometimes the container runtime can have difficulties initializing the GPU. When
 If none of those resolve the problem, gather additional information and file an issue:
 - Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
 - Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
+
+## Windows Terminal Errors
+
+Older versions of Windows 10 (e.g., 21H1) are known to have a bug where the standard terminal program does not display control characters correctly. This can result in a long string of strings like `←[?25h←[?25l` being displayed, sometimes erroring with `The parameter is incorrect` To resolve this problem, please update to Win 10 22H1 or newer.
@@ -19,7 +19,7 @@ Logs will often be helpful in diagnosing the problem (see

 ## System Requirements

-* Windows 10 or newer, Home or Pro
+* Windows 10 22H2 or newer, Home or Pro
 * NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
 * AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card

@@ -4,12 +4,14 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"math"
 	"net"
 	"os"
 	"path/filepath"
 	"runtime"
 	"strconv"
 	"strings"
+	"time"
 )

 type OllamaHost struct {
@@ -34,17 +36,17 @@ var (
 	// Set via OLLAMA_HOST in the environment
 	Host *OllamaHost
 	// Set via OLLAMA_KEEP_ALIVE in the environment
-	KeepAlive string
+	KeepAlive time.Duration
 	// Set via OLLAMA_LLM_LIBRARY in the environment
 	LLMLibrary string
 	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
 	MaxRunners int
 	// Set via OLLAMA_MAX_QUEUE in the environment
 	MaxQueuedRequests int
-	// Set via OLLAMA_MODELS in the environment
-	ModelsDir string
 	// Set via OLLAMA_MAX_VRAM in the environment
 	MaxVRAM uint64
+	// Set via OLLAMA_MODELS in the environment
+	ModelsDir string
 	// Set via OLLAMA_NOHISTORY in the environment
 	NoHistory bool
 	// Set via OLLAMA_NOPRUNE in the environment
@@ -132,6 +134,7 @@ func init() {
 	NumParallel = 0 // Autoselect
 	MaxRunners = 0 // Autoselect
 	MaxQueuedRequests = 512
+	KeepAlive = 5 * time.Minute

 	LoadConfig()
 }
@@ -266,7 +269,10 @@ func LoadConfig() {
 		}
 	}

-	KeepAlive = clean("OLLAMA_KEEP_ALIVE")
+	ka := clean("OLLAMA_KEEP_ALIVE")
+	if ka != "" {
+		loadKeepAlive(ka)
+	}

 	var err error
 	ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
 		Port: port,
 	}, nil
 }
+
+func loadKeepAlive(ka string) {
+	v, err := strconv.Atoi(ka)
+	if err != nil {
+		d, err := time.ParseDuration(ka)
+		if err == nil {
+			if d < 0 {
+				KeepAlive = time.Duration(math.MaxInt64)
+			} else {
+				KeepAlive = d
+			}
+		}
+	} else {
+		d := time.Duration(v) * time.Second
+		if d < 0 {
+			KeepAlive = time.Duration(math.MaxInt64)
+		} else {
+			KeepAlive = d
+		}
+	}
+}
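The loadKeepAlive helper above accepts either a bare integer (seconds) or a Go duration string, and maps any negative value to the maximum duration. A small sketch, assuming the github.com/ollama/ollama/envconfig import path (the config_test.go hunk below checks the same behaviour through t.Setenv):

```go
package main

import (
	"fmt"
	"os"

	"github.com/ollama/ollama/envconfig"
)

func main() {
	// Expected results per the parsing rules above: "3" -> 3s, "1h" -> 1h0m0s,
	// "-1" -> the maximum duration (effectively "keep loaded forever").
	for _, v := range []string{"3", "1h", "-1"} {
		os.Setenv("OLLAMA_KEEP_ALIVE", v)
		envconfig.LoadConfig()
		fmt.Println(v, "->", envconfig.KeepAlive)
	}
}
```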
@@ -2,8 +2,10 @@ package envconfig

 import (
 	"fmt"
+	"math"
 	"net"
 	"testing"
+	"time"

 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
 	LoadConfig()
 	require.True(t, FlashAttention)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "")
+	LoadConfig()
+	require.Equal(t, 5*time.Minute, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "3")
+	LoadConfig()
+	require.Equal(t, 3*time.Second, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
+	LoadConfig()
+	require.Equal(t, 1*time.Hour, KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
+	t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
+	LoadConfig()
+	require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
 }

 func TestClientFromEnvironment(t *testing.T) {
go.mod: 3 changes

@@ -18,6 +18,7 @@ require (
 require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
+	github.com/google/go-cmp v0.6.0
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
@@ -71,7 +72,7 @@ require (
 	golang.org/x/net v0.25.0 // indirect
 	golang.org/x/sys v0.20.0
 	golang.org/x/term v0.20.0
-	golang.org/x/text v0.15.0 // indirect
+	golang.org/x/text v0.15.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
@@ -49,9 +49,17 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
 }

 func commonAMDValidateLibDir() (string, error) {
-	// We try to favor system paths first, so that we can wire up the subprocess to use
-	// the system version. Only use our bundled version if the system version doesn't work
-	// This gives users a more recovery options if versions have subtle problems at runtime
+	// Favor our bundled version
+
+	// Installer payload location if we're running the installed binary
+	exe, err := os.Executable()
+	if err == nil {
+		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
+		if rocmLibUsable(rocmTargetDir) {
+			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
+			return rocmTargetDir, nil
+		}
+	}

 	// Prefer explicit HIP env var
 	hipPath := os.Getenv("HIP_PATH")
@@ -87,14 +95,5 @@ func commonAMDValidateLibDir() (string, error) {
 		}
 	}

-	// Installer payload location if we're running the installed binary
-	exe, err := os.Executable()
-	if err == nil {
-		rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
-		if rocmLibUsable(rocmTargetDir) {
-			slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
-			return rocmTargetDir, nil
-		}
-	}
 	return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
 }
@@ -84,9 +84,8 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
 	}

 	slog.Debug("hipDriverGetVersion", "version", version)
-	// TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
-	driverMajor = version / 1000
-	driverMinor = (version - (driverMajor * 1000)) / 10
+	driverMajor = version / 10000000
+	driverMinor = (version - (driverMajor * 10000000)) / 100000

 	return driverMajor, driverMinor, nil
 }
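For a sense of the corrected arithmetic above, here is the decode applied to an invented hipDriverGetVersion value (illustrative only, not a measured value):

```go
package main

import "fmt"

func main() {
	version := 60140092 // invented example value
	driverMajor := version / 10000000                            // 6
	driverMinor := (version - (driverMajor * 10000000)) / 100000 // 1
	fmt.Println(driverMajor, driverMinor)                        // prints: 6 1
}
```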
@@ -22,8 +22,8 @@ const (

 var (
 	// Used to validate if the given ROCm lib is usable
-	ROCmLibGlobs          = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
-	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
+	ROCmLibGlobs          = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
+	RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob?
 )

 func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,12 +35,11 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 	}
 	defer hl.Release()

-	// TODO - this reports incorrect version information, so omitting for now
-	// driverMajor, driverMinor, err := hl.AMDDriverVersion()
-	// if err != nil {
-	// 	// For now this is benign, but we may eventually need to fail compatibility checks
-	// 	slog.Debug("error looking up amd driver version", "error", err)
-	// }
+	driverMajor, driverMinor, err := hl.AMDDriverVersion()
+	if err != nil {
+		// For now this is benign, but we may eventually need to fail compatibility checks
+		slog.Debug("error looking up amd driver version", "error", err)
+	}

 	// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
 	count := hl.HipGetDeviceCount()
@@ -132,10 +131,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 				MinimumMemory: rocmMinimumMemory,
 				Name:          name,
 				Compute:       gfx,
-				// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
-				// DriverMajor:   driverMajor,
-				// DriverMinor:   driverMinor,
+				DriverMajor:   driverMajor,
+				DriverMinor:   driverMinor,
 			},
 			index: i,
 		}
gpu/gpu.go (53 lines changed)

@@ -202,7 +202,7 @@ func GetGPUInfo() GpuInfoList {
 	}()
 
 	if !bootstrapped {
-		slog.Debug("Detecting GPUs")
+		slog.Info("looking for compatible GPUs")
 		needRefresh = false
 		cpuCapability = GetCPUCapability()
 		var memInfo C.mem_info_t
@@ -274,6 +274,28 @@ func GetGPUInfo() GpuInfoList {
 			gpuInfo.DriverMajor = driverMajor
 			gpuInfo.DriverMinor = driverMinor
 
+			// query the management library as well so we can record any skew between the two
+			// which represents overhead on the GPU we must set aside on subsequent updates
+			if cHandles.nvml != nil {
+				C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
+				if memInfo.err != nil {
+					slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
+					C.free(unsafe.Pointer(memInfo.err))
+				} else {
+					if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
+						gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
+						slog.Info("detected OS VRAM overhead",
+							"id", gpuInfo.ID,
+							"library", gpuInfo.Library,
+							"compute", gpuInfo.Compute,
+							"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
+							"name", gpuInfo.Name,
+							"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
+						)
+					}
+				}
+			}
+
 			// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
 			cudaGPUs = append(cudaGPUs, gpuInfo)
 		}
@@ -320,6 +342,9 @@ func GetGPUInfo() GpuInfoList {
 
 		rocmGPUs = AMDGetGPUInfo()
 		bootstrapped = true
+		if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 {
+			slog.Info("no compatible GPUs were discovered")
+		}
 	}
 
 	// For detected GPUs, load library if not loaded
@@ -335,14 +360,17 @@ func GetGPUInfo() GpuInfoList {
 				"before",
 				"total", format.HumanBytes2(cpus[0].TotalMemory),
 				"free", format.HumanBytes2(cpus[0].FreeMemory),
+				"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
 			),
 			slog.Group(
 				"now",
 				"total", format.HumanBytes2(mem.TotalMemory),
 				"free", format.HumanBytes2(mem.FreeMemory),
+				"free_swap", format.HumanBytes2(mem.FreeSwap),
 			),
 		)
 		cpus[0].FreeMemory = mem.FreeMemory
+		cpus[0].FreeSwap = mem.FreeSwap
 	}
 
 	var memInfo C.mem_info_t
@@ -371,9 +399,14 @@ func GetGPUInfo() GpuInfoList {
 				slog.Warn("error looking up nvidia GPU memory")
 				continue
 			}
+			if cHandles.nvml != nil && gpu.OSOverhead > 0 {
+				// When using the management library update based on recorded overhead
+				memInfo.free -= C.uint64_t(gpu.OSOverhead)
+			}
 			slog.Debug("updating cuda memory data",
 				"gpu", gpu.ID,
 				"name", gpu.Name,
+				"overhead", format.HumanBytes2(gpu.OSOverhead),
 				slog.Group(
 					"before",
 					"total", format.HumanBytes2(gpu.TotalMemory),
@@ -514,7 +547,23 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
 		defer C.free(unsafe.Pointer(lib))
 		C.nvcuda_init(lib, &resp)
 		if resp.err != nil {
-			slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
+			// Decide what log level based on the type of error message to help users understand why
+			msg := C.GoString(resp.err)
+			switch resp.cudaErr {
+			case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH:
+				slog.Warn("version mismatch between driver and cuda driver library - reboot or upgrade may be required", "library", libPath, "error", msg)
+			case C.CUDA_ERROR_NO_DEVICE:
+				slog.Info("no nvidia devices detected", "library", libPath)
+			case C.CUDA_ERROR_UNKNOWN:
+				slog.Warn("unknown error initializing cuda driver library", "library", libPath, "error", msg)
+				slog.Warn("see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information")
+			default:
+				if strings.Contains(msg, "wrong ELF class") {
+					slog.Debug("skipping 32bit library", "library", libPath)
+				} else {
+					slog.Info("unable to load cuda driver library", "library", libPath, "error", msg)
+				}
+			}
 			C.free(unsafe.Pointer(resp.err))
 		} else {
 			return int(resp.num_devices), &resp.ch, libPath
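For clarity on the OSOverhead bookkeeping introduced above: at bootstrap the management library (NVML) and the CUDA driver library can disagree about free VRAM, and the gap is remembered so later free-memory readings are reduced and the scheduler does not over-commit. A rough Go sketch of that accounting, with illustrative names and byte counts rather than the actual ollama types:

package main

import "fmt"

type gpuMem struct {
	free       uint64 // free bytes as seen by the cuda driver library at bootstrap
	osOverhead uint64 // extra "free" bytes the management library reported
}

// recordOverhead captures the skew between the two views once, at discovery time.
func recordOverhead(g *gpuMem, nvmlFree uint64) {
	if nvmlFree > g.free {
		g.osOverhead = nvmlFree - g.free
	}
}

// adjustedFree applies the recorded overhead to a later free-memory reading.
func adjustedFree(g *gpuMem, reportedFree uint64) uint64 {
	if reportedFree > g.osOverhead {
		return reportedFree - g.osOverhead
	}
	return 0
}

func main() {
	g := &gpuMem{free: 7_000_000_000}
	recordOverhead(g, 7_500_000_000)            // management library saw ~500 MB more
	fmt.Println(adjustedFree(g, 6_800_000_000)) // prints: 6300000000
}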
@@ -56,7 +56,8 @@ func GetCPUInfo() GpuInfoList {
 func GetCPUMem() (memInfo, error) {
 	return memInfo{
 		TotalMemory: uint64(C.getPhysicalMemory()),
-		FreeMemory:  0,
+		FreeMemory:  uint64(C.getFreeMemory()),
+		// FreeSwap omitted as Darwin uses dynamic paging
 	}, nil
 }
 
@@ -2,3 +2,4 @@
 #include <stdint.h>
 uint64_t getRecommendedMaxVRAM();
 uint64_t getPhysicalMemory();
+uint64_t getFreeMemory();
@@ -1,4 +1,5 @@
-// go:build darwin
+#import <Foundation/Foundation.h>
+#import <mach/mach.h>
 #include "gpu_info_darwin.h"
 
 uint64_t getRecommendedMaxVRAM() {
@@ -8,6 +9,27 @@ uint64_t getRecommendedMaxVRAM() {
   return result;
 }
 
+// getPhysicalMemory returns the total physical memory in bytes
 uint64_t getPhysicalMemory() {
-  return [[NSProcessInfo processInfo] physicalMemory];
+  return [NSProcessInfo processInfo].physicalMemory;
+}
+
+// getFreeMemory returns the total free memory in bytes, including inactive
+// memory that can be reclaimed by the system.
+uint64_t getFreeMemory() {
+  mach_port_t host_port = mach_host_self();
+  mach_msg_type_number_t host_size = sizeof(vm_statistics64_data_t) / sizeof(integer_t);
+  vm_size_t pagesize;
+  vm_statistics64_data_t vm_stat;
+
+  host_page_size(host_port, &pagesize);
+  if (host_statistics64(host_port, HOST_VM_INFO64, (host_info64_t)&vm_stat, &host_size) != KERN_SUCCESS) {
+    return 0;
+  }
+
+  uint64_t free_memory = (uint64_t)vm_stat.free_count * pagesize;
+  free_memory += (uint64_t)vm_stat.speculative_count * pagesize;
+  free_memory += (uint64_t)vm_stat.inactive_count * pagesize;
+
+  return free_memory;
 }
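The intent of getFreeMemory above is that macOS keeps many reclaimable pages on the inactive and speculative lists, so counting only free_count would badly understate usable memory. A small Go sketch of the same arithmetic, using made-up page counts (16 KiB pages are typical on Apple Silicon; host_page_size reports the real value):

package main

import "fmt"

func main() {
	const pageSize = 16384 // bytes per page; illustrative
	var freeCount, speculativeCount, inactiveCount uint64 = 50_000, 10_000, 200_000

	// Mirror of the calculation in the diff: free + speculative + inactive pages.
	free := (freeCount + speculativeCount + inactiveCount) * pageSize
	fmt.Println(free) // prints: 4259840000 bytes, roughly 4 GiB of reclaimable memory
}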
@@ -7,6 +7,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
   CUresult ret;
   resp->err = NULL;
   resp->num_devices = 0;
+  resp->cudaErr = CUDA_SUCCESS;
   const int buflen = 256;
   char buf[buflen + 1];
   int i;
@@ -38,6 +39,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
             nvcuda_lib_path, msg);
     free(msg);
     resp->err = strdup(buf);
+    resp->cudaErr = -1;
     return;
   }
 
@@ -52,6 +54,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
               msg);
       free(msg);
       resp->err = strdup(buf);
+      resp->cudaErr = -1;
       return;
     }
   }
@@ -61,12 +64,9 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
     LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
-    if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
-      resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
-      return;
-    }
-    snprintf(buf, buflen, "nvcuda init failure: %d", ret);
+    snprintf(buf, buflen, "cuda driver library init failure: %d", ret);
     resp->err = strdup(buf);
+    resp->cudaErr = ret;
     return;
   }
 
@@ -91,6 +91,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
     resp->ch.handle = NULL;
     snprintf(buf, buflen, "unable to get device count: %d", ret);
     resp->err = strdup(buf);
+    resp->cudaErr = ret;
     return;
   }
 }
@@ -106,13 +107,13 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
 
   if (h.handle == NULL) {
-    resp->err = strdup("nvcuda handle isn't initialized");
+    resp->err = strdup("cuda driver library handle isn't initialized");
     return;
   }
 
   ret = (*h.cuDeviceGet)(&device, i);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda device failed to initialize");
+    snprintf(buf, buflen, "cuda driver library device failed to initialize");
     resp->err = strdup(buf);
     return;
   }
@@ -168,14 +169,14 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
+    snprintf(buf, buflen, "cuda driver library failed to get device context %d", ret);
     resp->err = strdup(buf);
     return;
   }
 
   ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
   if (ret != CUDA_SUCCESS) {
-    snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
+    snprintf(buf, buflen, "cuda driver library device memory info lookup failure %d", ret);
     resp->err = strdup(buf);
     // Best effort on failure...
     (*h.cuCtxDestroy)(ctx);
@@ -193,7 +194,7 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release device context %d", ret);
+    LOG(1, "cuda driver library failed to release device context %d", ret);
   }
 }
 
@@ -206,7 +207,7 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
 
   ret = (*h.cuDeviceGet)(&device, i);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda device failed to initialize");
+    LOG(1, "cuda driver library device failed to initialize");
     return;
   }
 
@@ -214,13 +215,13 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
   // To get memory we have to set (and release) a context
   ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to get device context %d", ret);
+    LOG(1, "cuda driver library failed to get device context %d", ret);
    return;
   }
 
   ret = (*h.cuMemGetInfo_v2)(free, total);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda device memory info lookup failure %d", ret);
+    LOG(1, "cuda driver library device memory info lookup failure %d", ret);
     // Best effort on failure...
     (*h.cuCtxDestroy)(ctx);
     return;
@@ -228,12 +229,12 @@ void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total)
 
   ret = (*h.cuCtxDestroy)(ctx);
   if (ret != CUDA_SUCCESS) {
-    LOG(1, "nvcuda failed to release device context %d", ret);
+    LOG(1, "cuda driver library failed to release device context %d", ret);
   }
 }
 
 void nvcuda_release(nvcuda_handle_t h) {
-  LOG(h.verbose, "releasing nvcuda library\n");
+  LOG(h.verbose, "releasing cuda driver library\n");
   UNLOAD_LIBRARY(h.handle);
   // TODO and other context release logic?
   h.handle = NULL;
@@ -7,9 +7,12 @@
 typedef enum cudaError_enum {
   CUDA_SUCCESS = 0,
   CUDA_ERROR_INVALID_VALUE = 1,
-  CUDA_ERROR_MEMORY_ALLOCATION = 2,
+  CUDA_ERROR_OUT_OF_MEMORY = 2,
   CUDA_ERROR_NOT_INITIALIZED = 3,
   CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
+  CUDA_ERROR_NO_DEVICE = 100,
+  CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803,
+  CUDA_ERROR_UNKNOWN = 999,
   // Other values omitted for now...
 } CUresult;
 
@@ -64,6 +67,7 @@ typedef struct nvcuda_init_resp {
   char *err;  // If err is non-null handle is invalid
   nvcuda_handle_t ch;
   int num_devices;
+  CUresult cudaErr;
 } nvcuda_init_resp_t;
 
 void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
 
 func GetCPUMem() (memInfo, error) {
 	var mem memInfo
-	var total, available, free, buffers, cached uint64
+	var total, available, free, buffers, cached, freeSwap uint64
 	f, err := os.Open("/proc/meminfo")
 	if err != nil {
 		return mem, err
@@ -70,20 +70,21 @@ func GetCPUMem() (memInfo, error) {
 			_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
 		case strings.HasPrefix(line, "Cached:"):
 			_, err = fmt.Sscanf(line, "Cached:%d", &cached)
+		case strings.HasPrefix(line, "SwapFree:"):
+			_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
 		default:
 			continue
 		}
 		if err != nil {
 			return mem, err
 		}
-
-		if total > 0 && available > 0 {
-			mem.TotalMemory = total * format.KibiByte
-			mem.FreeMemory = available * format.KibiByte
-			return mem, nil
-		}
 	}
 	mem.TotalMemory = total * format.KibiByte
-	mem.FreeMemory = (free + buffers + cached) * format.KibiByte
+	mem.FreeSwap = freeSwap * format.KibiByte
+	if available > 0 {
+		mem.FreeMemory = available * format.KibiByte
+	} else {
+		mem.FreeMemory = (free + buffers + cached) * format.KibiByte
+	}
 	return mem, nil
 }
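The reworked Linux path above prefers the kernel's MemAvailable estimate, falls back to free+buffers+cached when it is absent, and now also records SwapFree. A self-contained sketch of the same Sscanf pattern run against canned /proc/meminfo lines (values are illustrative and in KiB, as the kernel reports them):

package main

import (
	"fmt"
	"strings"
)

func main() {
	lines := []string{
		"MemTotal:       16384000 kB",
		"MemAvailable:   12000000 kB",
		"SwapFree:        2048000 kB",
	}

	var total, available, freeSwap uint64
	for _, line := range lines {
		switch {
		case strings.HasPrefix(line, "MemTotal:"):
			fmt.Sscanf(line, "MemTotal:%d", &total)
		case strings.HasPrefix(line, "MemAvailable:"):
			fmt.Sscanf(line, "MemAvailable:%d", &available)
		case strings.HasPrefix(line, "SwapFree:"):
			fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
		}
	}

	const kibiByte = 1024
	fmt.Println(total*kibiByte, available*kibiByte, freeSwap*kibiByte)
}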
@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
 	if r1 == 0 {
 		return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
 	}
-	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
+	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
 }
@@ -10,6 +10,7 @@ import (
 type memInfo struct {
 	TotalMemory uint64 `json:"total_memory,omitempty"`
 	FreeMemory  uint64 `json:"free_memory,omitempty"`
+	FreeSwap    uint64 `json:"free_swap,omitempty"`
 }
 
 // Beginning of an `ollama info` command
@@ -52,7 +53,8 @@ type CPUInfo struct {
 
 type CudaGPUInfo struct {
 	GpuInfo
-	index int //nolint:unused,nolintlint
+	OSOverhead uint64 // Memory overhead between the driver library and management library
+	index      int    //nolint:unused,nolintlint
 }
 type CudaGPUInfoList []CudaGPUInfo
 
integration/embed_test.go (new file, 152 lines; entire contents added)

@@ -0,0 +1,152 @@
//go:build integration

package integration

import (
	"context"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestAllMiniLMEmbed(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbedRequest{
		Model: "all-minilm",
		Input: "why is the sky blue?",
	}

	res, err := embedTestHelper(ctx, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
	}

	if len(res.Embeddings) != 1 {
		t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
	}

	if len(res.Embeddings[0]) != 384 {
		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
	}

	if res.Embeddings[0][0] != 0.010071031 {
		t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
	}
}

func TestAllMiniLMBatchEmbed(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	}

	res, err := embedTestHelper(ctx, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
	}

	if len(res.Embeddings) != 2 {
		t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
	}

	if len(res.Embeddings[0]) != 384 {
		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
	}

	if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
		t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
	}
}

func TestAllMiniLmEmbedTruncate(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	truncTrue, truncFalse := true, false

	type testReq struct {
		Name    string
		Request api.EmbedRequest
	}

	reqs := []testReq{
		{
			Name: "Target Truncation",
			Request: api.EmbedRequest{
				Model: "all-minilm",
				Input: "why",
			},
		},
		{
			Name: "Default Truncate",
			Request: api.EmbedRequest{
				Model:   "all-minilm",
				Input:   "why is the sky blue?",
				Options: map[string]any{"num_ctx": 1},
			},
		},
		{
			Name: "Explicit Truncate",
			Request: api.EmbedRequest{
				Model:    "all-minilm",
				Input:    "why is the sky blue?",
				Truncate: &truncTrue,
				Options:  map[string]any{"num_ctx": 1},
			},
		},
	}

	res := make(map[string]*api.EmbedResponse)

	for _, req := range reqs {
		response, err := embedTestHelper(ctx, t, req.Request)
		if err != nil {
			t.Fatalf("error: %v", err)
		}
		res[req.Name] = response
	}

	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
		t.Fatal("expected default request to truncate correctly")
	}

	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
		t.Fatal("expected default request and truncate true request to be the same")
	}

	// check that truncate set to false returns an error if context length is exceeded
	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
		Model:    "all-minilm",
		Input:    "why is the sky blue?",
		Truncate: &truncFalse,
		Options:  map[string]any{"num_ctx": 1},
	})

	if err == nil {
		t.Fatal("expected error, got nil")
	}
}

func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("failed to pull model %s: %v", req.Model, err)
	}

	response, err := client.Embed(ctx, &req)

	if err != nil {
		return nil, err
	}

	return response, nil
}
llm/ext_server/CMakeLists.txt (25 lines changed, vendored)

@@ -1,14 +1,13 @@
 set(TARGET ollama_llama_server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-target_link_libraries(${TARGET} PRIVATE ggml llama common llava ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 if (WIN32)
     TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
 endif()
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
llm/ext_server/server.cpp (108 lines changed, vendored)

@@ -1382,12 +1382,50 @@ struct llama_server_context
         }
     }
 
+    std::string common_prefix(const std::string& str1, const std::string& str2) {
+        auto mismatch_pair = std::mismatch(str1.begin(), str1.end(), str2.begin());
+        return std::string(str1.begin(), mismatch_pair.first);
+    }
+
+    // Find the slot that has the greatest common prefix
+    server_slot *prefix_slot(const json &prompt) {
+        if (!prompt.is_string()) {
+            return nullptr;
+        }
+
+        std::string prompt_str = prompt.get<std::string>();
+        server_slot *slot = nullptr;
+        size_t longest = 0;
+
+        for (server_slot &s : slots) {
+            if (s.available() && s.prompt.is_string()) {
+                std::string s_prompt = s.prompt.get<std::string>();
+                std::string prefix = common_prefix(s_prompt, prompt_str);
+
+                if (prefix.size() > longest) {
+                    slot = &s;
+                    longest = prefix.size();
+                }
+            }
+        }
+
+        if (!slot) {
+            return get_slot(-1);
+        }
+
+        LOG_DEBUG("slot with common prefix found", {{
+            "slot_id", slot->id,
+            "characters", longest
+        }});
+        return slot;
+    }
+
     void process_single_task(task_server& task)
     {
         switch (task.type)
         {
             case TASK_TYPE_COMPLETION: {
-                server_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
+                server_slot *slot = prefix_slot(task.data["prompt"]);
                 if (slot == nullptr)
                 {
                     // if no slot is available, we defer this task for processing later
@@ -1650,22 +1688,8 @@ struct llama_server_context
             }
             slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-            char buf[256];
-            llama_model_meta_val_str(model, "general.architecture", buf, 256);
-            bool gemma2 = strcmp(buf, "gemma2") == 0;
-
-            int32_t truncate_at = slot.n_ctx;
-
-            // truncate at 2/3 of the context length for gemma2 models
-            // as they do not support context shifts (from the sliding window implementation).
-            // this way, prompts that almost fit the context length can still generate a full
-            // response without a sudden stop from hitting the context limit
-            if (gemma2) {
-                truncate_at = 2 * slot.n_ctx / 3;
-            }
-
             // if input prompt is too big, truncate it, if group attention self-extend is disabled
-            if (slot.ga_n == 1 && slot.n_prompt_tokens >= truncate_at)
+            if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
             {
                 const int n_left = slot.n_ctx - slot.params.n_keep;
                 const int n_shift = n_left / 2;
@@ -1693,19 +1717,6 @@ struct llama_server_context
                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
             }
 
-            // Models with sliding window attention do not work with context shifts, so
-            // limit their prediction to the context length
-            if (gemma2) {
-                int32_t limit = slot.n_ctx - slot.n_prompt_tokens;
-                slot.n_predict = limit;
-                slot.params.n_predict = limit;
-                LOG_INFO("model does not support sliding window, limiting generation", {
-                    {"n_ctx", slot.n_ctx},
-                    {"n_prompt_tokens", slot.n_prompt_tokens},
-                    {"n_predict", slot.n_predict}
-                });
-            }
-
             if (!slot.params.cache_prompt)
             {
                 llama_sampling_reset(slot.ctx_sampling);
@@ -1732,7 +1743,7 @@ struct llama_server_context
                     slot.n_past -= 1;
                 }
 
-                slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
+                slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
 
                 if (slot.ga_n != 1)
                 {
@@ -3177,26 +3188,33 @@ int main(int argc, char **argv) {
                 prompt = "";
             }
 
-            json image_data;
-            if (body.count("image_data") != 0) {
-                image_data = body["image_data"];
-            }
-            else
-            {
-                image_data = "";
+            if (prompt.size() == 1) {
+                prompt = prompt[0];
             }
 
             // create and queue the task
-            const int task_id = llama.queue_tasks.get_new_id();
-            llama.queue_results.add_waiting_task_id(task_id);
-            llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
-
-            // get the result
-            task_result result = llama.queue_results.recv(task_id);
-            llama.queue_results.remove_waiting_task_id(task_id);
-
-            // send the result
-            return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+            json responses;
+            {
+                const int id_task = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(id_task);
+                llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
+
+                // get the result
+                task_result result = llama.queue_results.recv(id_task);
+                llama.queue_results.remove_waiting_task_id(id_task);
+                if (result.error) {
+                    return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+                }
+
+                responses = result.result_json.value("results", std::vector<json>{result.result_json});
+                json embeddings = json::array();
+                for (auto & elem : responses) {
+                    embeddings.push_back(elem.at("embedding"));
+                }
+                // send the result
+                json embedding_res = json{{"embedding", embeddings}};
+                return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
+            }
         });
 
         // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
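For intuition on the prefix_slot change above: when several slots hold cached prompts, the server now picks the available slot whose cached prompt shares the longest leading run of characters with the incoming prompt, so more of that slot's KV cache can be reused. A small Go sketch of the same selection rule (the prompts are illustrative):

package main

import "fmt"

// commonPrefixLen counts the leading bytes two prompts share, mirroring the
// common_prefix helper in the diff above.
func commonPrefixLen(a, b string) int {
	n := 0
	for n < len(a) && n < len(b) && a[n] == b[n] {
		n++
	}
	return n
}

func main() {
	cached := []string{"why is the sky blue?", "tell me a joke"}
	incoming := "why is the grass green?"

	best, longest := -1, 0
	for i, p := range cached {
		if l := commonPrefixLen(p, incoming); l > longest {
			best, longest = i, l
		}
	}
	fmt.Println(best, longest) // prints: 0 11 (slot 0 shares the prefix "why is the ")
}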
@@ -18,16 +18,16 @@ sign() {
     fi
 }
 
-COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_OPENMP=off"
+COMMON_DARWIN_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DGGML_METAL_EMBED_LIBRARY=on -DGGML_OPENMP=off"
 
 case "${GOARCH}" in
 "amd64")
-    COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
+    COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DGGML_METAL=off -DGGML_NATIVE=off"
 
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_BLAS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_BLAS=off -DGGML_ACCELERATE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}_static"
     echo "Building static library"
     build
@@ -37,7 +37,7 @@ case "${GOARCH}" in
     # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
     #
     init_vars
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}/cpu"
     echo "Building LCD CPU"
     build
@@ -49,7 +49,7 @@ case "${GOARCH}" in
     # Approximately 400% faster than LCD on same CPU
     #
     init_vars
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=off -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
     echo "Building AVX CPU"
     build
@@ -61,7 +61,7 @@ case "${GOARCH}" in
     # Approximately 10% faster than AVX on same CPU
     #
     init_vars
-    CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_BLAS=off -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_ACCELERATE=on -DGGML_BLAS=off -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
     echo "Building AVX2 CPU"
     EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
@@ -75,14 +75,14 @@ case "${GOARCH}" in
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_BLAS=off -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}_static"
     echo "Building static library"
     build
 
 if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then
     init_vars
-    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
+    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} ${CMAKE_DEFS}"
     BUILD_DIR="../build/darwin/${ARCH}/metal"
     EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
     build
@@ -51,7 +51,7 @@ if [ -z "${CUDACXX}" ]; then
         export CUDACXX=$(command -v nvcc)
     fi
 fi
-COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off"
+COMMON_CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off"
 source $(dirname $0)/gen_common.sh
 init_vars
 git_module_setup
@@ -64,7 +64,7 @@ if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ];
     # Static build for linking into the Go binary
     init_vars
     CMAKE_TARGETS="--target llama --target ggml"
-    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_OPENMP=off ${CMAKE_DEFS}"
+    CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DGGML_NATIVE=off -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_OPENMP=off ${CMAKE_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}_static"
     echo "Building static library"
     build
@@ -77,29 +77,29 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
     if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
         init_vars
         echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
-        CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
+        CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
         BUILD_DIR="../build/linux/${ARCH}/cpu"
         echo "Building custom CPU"
         build
         compress
     else
         # Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
-        # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
-        # -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
-        # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
-        # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
+        # -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
+        # -DGGML_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
+        # -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
+        # -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
         # Note: the following seem to yield slower results than AVX2 - ymmv
-        # -DLLAMA_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
-        # -DLLAMA_AVX512_VBMI -- 2018 Intel Cannon Lake
-        # -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
+        # -DGGML_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
+        # -DGGML_AVX512_VBMI -- 2018 Intel Cannon Lake
+        # -DGGML_AVX512_VNNI -- 2021 Intel Alder Lake
 
-        COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_OPENMP=off"
+        COMMON_CPU_DEFS="-DBUILD_SHARED_LIBS=off -DCMAKE_POSITION_INDEPENDENT_CODE=on -DGGML_NATIVE=off -DGGML_OPENMP=off"
         if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
             #
            # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
            #
            init_vars
-            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
             BUILD_DIR="../build/linux/${ARCH}/cpu"
             echo "Building LCD CPU"
             build
@@ -116,7 +116,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
             # Approximately 400% faster than LCD on same CPU
             #
             init_vars
-            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off ${CMAKE_DEFS}"
             BUILD_DIR="../build/linux/${ARCH}/cpu_avx"
             echo "Building AVX CPU"
             build
@@ -129,7 +129,7 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
             # Approximately 10% faster than AVX on same CPU
             #
             init_vars
-            CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
+            CMAKE_DEFS="${COMMON_CPU_DEFS} -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on ${CMAKE_DEFS}"
             BUILD_DIR="../build/linux/${ARCH}/cpu_avx2"
             echo "Building AVX2 CPU"
             build
@@ -170,15 +170,15 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
         #
         # CUDA compute < 6.0 lacks proper FP16 support on ARM.
         # Disabling has minimal performance effect while maintaining compatibility.
-        ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
+        ARM64_DEFS="-DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_CUDA_F16=off"
     fi
     # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
     if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
         echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
-        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
+        CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
         echo "Building custom CUDA GPU"
     else
-        CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
+        CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
     fi
     CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
     BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
@@ -216,7 +216,7 @@ if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
     init_vars
     source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
     CC=icx
-    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL=ON -DLLAMA_SYCL_F16=OFF"
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL=ON -DGGML_SYCL_F16=OFF"
     BUILD_DIR="../build/linux/${ARCH}/oneapi"
     EXTRA_LIBS="-fsycl -Wl,-rpath,${ONEAPI_ROOT}/compiler/latest/lib,-rpath,${ONEAPI_ROOT}/mkl/latest/lib,-rpath,${ONEAPI_ROOT}/tbb/latest/lib,-rpath,${ONEAPI_ROOT}/compiler/latest/opt/oclfpga/linux64/lib -lOpenCL -lmkl_core -lmkl_sycl_blas -lmkl_intel_ilp64 -lmkl_tbb_thread -ltbb"
     DEBUG_FLAGS="" # icx compiles with -O0 if we pass -g, so we must remove it
@@ -254,7 +254,7 @@ if [ -z "${OLLAMA_SKIP_ROCM_GENERATE}" -a -d "${ROCM_PATH}" ]; then
         ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true)
     fi
     init_vars
-    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
+    CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DGGML_HIPBLAS=on -DLLAMA_CUDA_NO_PEER_COPY=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
     # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
     if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
         echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
@@ -6,18 +6,9 @@ function amdGPUs {
|
|||||||
if ($env:AMDGPU_TARGETS) {
|
if ($env:AMDGPU_TARGETS) {
|
||||||
return $env:AMDGPU_TARGETS
|
return $env:AMDGPU_TARGETS
|
||||||
}
|
}
|
||||||
# TODO - load from some common data file for linux + windows build consistency
|
# Current supported rocblas list from ROCm v6.1.2 on windows
|
||||||
$GPU_LIST = @(
|
$GPU_LIST = @(
|
||||||
"gfx900"
|
|
||||||
"gfx906:xnack-"
|
"gfx906:xnack-"
|
||||||
"gfx908:xnack-"
|
|
||||||
"gfx90a:xnack+"
|
|
||||||
"gfx90a:xnack-"
|
|
||||||
"gfx940"
|
|
||||||
"gfx941"
|
|
||||||
"gfx942"
|
|
||||||
"gfx1010"
|
|
||||||
"gfx1012"
|
|
||||||
"gfx1030"
|
"gfx1030"
|
||||||
"gfx1100"
|
"gfx1100"
|
||||||
"gfx1101"
|
"gfx1101"
|
||||||
@@ -39,8 +30,8 @@ function init_vars {
|
|||||||
}
|
}
|
||||||
$script:cmakeDefs = @(
|
$script:cmakeDefs = @(
|
||||||
"-DBUILD_SHARED_LIBS=on",
|
"-DBUILD_SHARED_LIBS=on",
|
||||||
"-DLLAMA_NATIVE=off",
|
"-DGGML_NATIVE=off",
|
||||||
"-DLLAMA_OPENMP=off"
|
"-DGGML_OPENMP=off"
|
||||||
)
|
)
|
||||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||||
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
|
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||||
@@ -182,9 +173,9 @@ function cleanup {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
# -DGGML_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
||||||
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
# -DGGML_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
||||||
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
# -DGGML_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
||||||
|
|
||||||
|
|
||||||
function build_static() {
|
function build_static() {
|
||||||
@@ -204,13 +195,13 @@ function build_static() {
|
|||||||
"-DCMAKE_C_COMPILER=gcc.exe",
|
"-DCMAKE_C_COMPILER=gcc.exe",
|
||||||
"-DCMAKE_CXX_COMPILER=g++.exe",
|
"-DCMAKE_CXX_COMPILER=g++.exe",
|
||||||
"-DBUILD_SHARED_LIBS=off",
|
"-DBUILD_SHARED_LIBS=off",
|
||||||
"-DLLAMA_NATIVE=off",
|
"-DGGML_NATIVE=off",
|
||||||
"-DLLAMA_AVX=off",
|
"-DGGML_AVX=off",
|
||||||
"-DLLAMA_AVX2=off",
|
"-DGGML_AVX2=off",
|
||||||
"-DLLAMA_AVX512=off",
|
"-DGGML_AVX512=off",
|
||||||
"-DLLAMA_F16C=off",
|
"-DGGML_F16C=off",
|
||||||
"-DLLAMA_FMA=off",
|
"-DGGML_FMA=off",
|
||||||
"-DLLAMA_OPENMP=off")
|
"-DGGML_OPENMP=off")
|
||||||
$script:buildDir="../build/windows/${script:ARCH}_static"
|
$script:buildDir="../build/windows/${script:ARCH}_static"
|
||||||
write-host "Building static library"
|
write-host "Building static library"
|
||||||
build
|
build
|
||||||
@@ -224,7 +215,7 @@ function build_cpu($gen_arch) {
|
|||||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
||||||
# remaining llama.cpp builds use MSVC
|
# remaining llama.cpp builds use MSVC
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DGGML_AVX=off", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
||||||
$script:distDir="$script:DIST_BASE\cpu"
|
$script:distDir="$script:DIST_BASE\cpu"
|
||||||
write-host "Building LCD CPU"
|
write-host "Building LCD CPU"
|
||||||
@@ -239,7 +230,7 @@ function build_cpu($gen_arch) {
|
|||||||
function build_cpu_avx() {
|
function build_cpu_avx() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
|
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=off", "-DGGML_AVX512=off", "-DGGML_FMA=off", "-DGGML_F16C=off") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
||||||
$script:distDir="$script:DIST_BASE\cpu_avx"
|
$script:distDir="$script:DIST_BASE\cpu_avx"
|
||||||
write-host "Building AVX CPU"
|
write-host "Building AVX CPU"
|
||||||
@@ -254,7 +245,7 @@ function build_cpu_avx() {
|
|||||||
function build_cpu_avx2() {
|
function build_cpu_avx2() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
|
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DGGML_AVX=on", "-DGGML_AVX2=on", "-DGGML_AVX512=off", "-DGGML_FMA=on", "-DGGML_F16C=on") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
||||||
$script:distDir="$script:DIST_BASE\cpu_avx2"
|
$script:distDir="$script:DIST_BASE\cpu_avx2"
|
||||||
write-host "Building AVX2 CPU"
|
write-host "Building AVX2 CPU"
|
||||||
@@ -279,9 +270,9 @@ function build_cuda() {
|
|||||||
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
|
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
|
||||||
$script:cmakeDefs += @(
|
$script:cmakeDefs += @(
|
||||||
"-A", "x64",
|
"-A", "x64",
|
||||||
"-DLLAMA_CUDA=ON",
|
"-DGGML_CUDA=ON",
|
||||||
"-DLLAMA_AVX=on",
|
"-DGGML_AVX=on",
|
||||||
"-DLLAMA_AVX2=off",
|
"-DGGML_AVX2=off",
|
||||||
"-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR",
|
"-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR",
|
||||||
"-DCMAKE_CUDA_FLAGS=-t8",
|
"-DCMAKE_CUDA_FLAGS=-t8",
|
||||||
"-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}"
|
"-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}"
|
||||||
@@ -319,7 +310,7 @@ function build_oneapi() {
|
|||||||
$script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
|
$script:distDir ="$script:DIST_BASE\oneapi$script:ONEAPI_VARIANT"
|
||||||
$script:cmakeDefs += @(
|
$script:cmakeDefs += @(
|
||||||
"-G", "MinGW Makefiles",
|
"-G", "MinGW Makefiles",
|
||||||
"-DLLAMA_SYCL=ON",
|
"-DGGML_SYCL=ON",
|
||||||
"-DCMAKE_C_COMPILER=icx",
|
"-DCMAKE_C_COMPILER=icx",
|
||||||
"-DCMAKE_CXX_COMPILER=icx",
|
"-DCMAKE_CXX_COMPILER=icx",
|
||||||
"-DCMAKE_BUILD_TYPE=Release"
|
"-DCMAKE_BUILD_TYPE=Release"
|
||||||
@@ -365,10 +356,11 @@ function build_rocm() {
|
|||||||
"-G", "Ninja",
|
"-G", "Ninja",
|
||||||
"-DCMAKE_C_COMPILER=clang.exe",
|
"-DCMAKE_C_COMPILER=clang.exe",
|
||||||
"-DCMAKE_CXX_COMPILER=clang++.exe",
|
"-DCMAKE_CXX_COMPILER=clang++.exe",
|
||||||
"-DLLAMA_HIPBLAS=on",
|
"-DGGML_HIPBLAS=on",
|
||||||
|
"-DLLAMA_CUDA_NO_PEER_COPY=on",
|
||||||
"-DHIP_PLATFORM=amd",
|
"-DHIP_PLATFORM=amd",
|
||||||
"-DLLAMA_AVX=on",
|
"-DGGML_AVX=on",
|
||||||
"-DLLAMA_AVX2=off",
|
"-DGGML_AVX2=off",
|
||||||
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
|
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
|
||||||
"-DAMDGPU_TARGETS=$(amdGPUs)",
|
"-DAMDGPU_TARGETS=$(amdGPUs)",
|
||||||
"-DGPU_TARGETS=$(amdGPUs)"
|
"-DGPU_TARGETS=$(amdGPUs)"
|
||||||
@@ -394,7 +386,6 @@ function build_rocm() {
|
|||||||
sign
|
sign
|
||||||
install
|
install
|
||||||
|
|
||||||
# Assumes v5.7, may need adjustments for v6
|
|
||||||
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
||||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
26 llm/ggml.go
@@ -424,6 +424,32 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
 			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
 			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
 		)
+	case "chatglm":
+		fullOffload = 4 * batch * (embedding + vocab)
+		partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
+		if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
+			fullOffload = max(
+				fullOffload,
+				4*batch*(2+
+					2*embedding+
+					context+
+					context*heads+
+					embeddingHeadsK*heads+
+					qkvBias.Shape[0]),
+			)
+
+			partialOffload = max(
+				partialOffload,
+				4*batch*(1+
+					2*embedding+
+					embeddingHeadsK*heads+
+					context+
+					context*heads)+
+					4*embeddingHeadsK*context+
+					4*context*embeddingHeadsK+
+					4*qkvBias.Shape[0],
+			)
+		}
 	}

 	return
@@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{
 		"tokenizer.ggml.add_bos_token",
 		"tokenizer.ggml.add_eos_token",
 		"tokenizer.chat_template",
+		"bert.pooling_type",
 	},
 }
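The chatglm case added above is pure arithmetic over the model hyperparameters. As a reading aid, here is a hypothetical standalone restatement of that estimate in Go; the function name and parameter list are invented for illustration, the formulas are taken from the diff, and qkvBiasDim stands in for qkvBias.Shape[0].

// chatglmGraphSize restates the chatglm branch of GraphSize above:
// it estimates the full- and partial-offload graph sizes in bytes.
// Requires Go 1.21+ for the built-in max.
func chatglmGraphSize(context, batch, embedding, vocab, heads, embeddingHeadsK, qkvBiasDim uint64) (partialOffload, fullOffload uint64) {
	fullOffload = 4 * batch * (embedding + vocab)
	partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128

	// the larger terms only apply when blk.0 carries an attn_qkv.bias tensor
	if qkvBiasDim > 0 {
		fullOffload = max(fullOffload,
			4*batch*(2+2*embedding+context+context*heads+embeddingHeadsK*heads+qkvBiasDim))
		partialOffload = max(partialOffload,
			4*batch*(1+2*embedding+embeddingHeadsK*heads+context+context*heads)+
				4*embeddingHeadsK*context+4*context*embeddingHeadsK+4*qkvBiasDim)
	}
	return
}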
Submodule llm/llama.cpp updated: 7c26775adb...a8db2a9ce6
17 llm/llm.go
@@ -1,12 +1,13 @@
 package llm

-// #cgo CFLAGS: -Illama.cpp
-// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
-// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
-// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
-// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
-// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
-// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
+// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
+// #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
+// #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
+// #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
+// #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
+// #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
+// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
+// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
 // #include "llama.h"
 import "C"
@@ -32,7 +33,7 @@ func Quantize(infile, outfile string, ftype fileType) error {
 	params.ftype = ftype.Value()

 	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", rc)
+		return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version")
 	}

 	return nil
|
@@ -1,8 +1,8 @@
|
|||||||
diff --git a/common/common.cpp b/common/common.cpp
|
diff --git a/common/common.cpp b/common/common.cpp
|
||||||
index 73ff0e85..6adb1a92 100644
|
index 2c05a4d4..927f0e3d 100644
|
||||||
--- a/common/common.cpp
|
--- a/common/common.cpp
|
||||||
+++ b/common/common.cpp
|
+++ b/common/common.cpp
|
||||||
@@ -2447,6 +2447,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
@@ -2093,6 +2093,8 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
|
||||||
mparams.use_mmap = params.use_mmap;
|
mparams.use_mmap = params.use_mmap;
|
||||||
mparams.use_mlock = params.use_mlock;
|
mparams.use_mlock = params.use_mlock;
|
||||||
mparams.check_tensors = params.check_tensors;
|
mparams.check_tensors = params.check_tensors;
|
||||||
@@ -12,10 +12,10 @@ index 73ff0e85..6adb1a92 100644
|
|||||||
mparams.kv_overrides = NULL;
|
mparams.kv_overrides = NULL;
|
||||||
} else {
|
} else {
|
||||||
diff --git a/common/common.h b/common/common.h
|
diff --git a/common/common.h b/common/common.h
|
||||||
index 58ed72f4..0bb2605e 100644
|
index 65c0ef81..ebca2c77 100644
|
||||||
--- a/common/common.h
|
--- a/common/common.h
|
||||||
+++ b/common/common.h
|
+++ b/common/common.h
|
||||||
@@ -180,6 +180,13 @@ struct gpt_params {
|
@@ -184,6 +184,13 @@ struct gpt_params {
|
||||||
std::string mmproj = ""; // path to multimodal projector
|
std::string mmproj = ""; // path to multimodal projector
|
||||||
std::vector<std::string> image; // path to image file(s)
|
std::vector<std::string> image; // path to image file(s)
|
||||||
|
|
||||||
@@ -26,6 +26,6 @@ index 58ed72f4..0bb2605e 100644
|
|||||||
+ // context pointer passed to the progress callback
|
+ // context pointer passed to the progress callback
|
||||||
+ void * progress_callback_user_data;
|
+ void * progress_callback_user_data;
|
||||||
+
|
+
|
||||||
// server params
|
// embedding
|
||||||
int32_t port = 8080; // server listens on this network port
|
bool embedding = false; // get only sentence embedding
|
||||||
int32_t timeout_read = 600; // http read timeout in seconds
|
int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||||
|
@@ -1,17 +1,8 @@
|
|||||||
From 544a2d2e646d39e878d87dfbb3398a356bc560ab Mon Sep 17 00:00:00 2001
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
From: Michael Yang <mxyng@pm.me>
|
index 73f52435..58a00fb1 100644
|
||||||
Date: Thu, 23 May 2024 11:18:45 -0700
|
--- a/src/llama.cpp
|
||||||
Subject: [PATCH] throw exception on load errors
|
+++ b/src/llama.cpp
|
||||||
|
@@ -7241,7 +7241,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
||||||
---
|
|
||||||
llama.cpp | 25 ++++++++++++++++---------
|
|
||||||
1 file changed, 16 insertions(+), 9 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/llama.cpp b/llama.cpp
|
|
||||||
index 15c66077..8ba90b6a 100644
|
|
||||||
--- a/llama.cpp
|
|
||||||
+++ b/llama.cpp
|
|
||||||
@@ -6346,7 +6346,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
|
|
||||||
}
|
}
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
||||||
@@ -20,7 +11,7 @@ index 15c66077..8ba90b6a 100644
|
|||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
@@ -15600,16 +15600,23 @@ struct llama_model * llama_load_model_from_file(
|
@@ -17564,16 +17564,23 @@ struct llama_model * llama_load_model_from_file(
|
||||||
}
|
}
|
||||||
model->rpc_servers.push_back(servers);
|
model->rpc_servers.push_back(servers);
|
||||||
}
|
}
|
||||||
@@ -52,6 +43,3 @@ index 15c66077..8ba90b6a 100644
|
|||||||
}
|
}
|
||||||
|
|
||||||
return model;
|
return model;
|
||||||
--
|
|
||||||
2.45.1
|
|
||||||
|
|
||||||
|
@@ -1,7 +1,7 @@
-diff --git a/ggml-metal.m b/ggml-metal.m
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
 index 0207b787..b5e9884b 100644
---- a/ggml-metal.m
-+++ b/ggml-metal.m
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
 @@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
      // to the matrix-vector kernel
      int ne11_mm_min = 1;
|
@@ -1,11 +1,11 @@
|
|||||||
diff --git a/llama.cpp b/llama.cpp
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
index 61948751..4b72a293 100644
|
index 2b9ace28..172640e2 100644
|
||||||
--- a/llama.cpp
|
--- a/src/llama.cpp
|
||||||
+++ b/llama.cpp
|
+++ b/src/llama.cpp
|
||||||
@@ -4824,16 +4824,7 @@ static void llm_load_vocab(
|
@@ -5357,16 +5357,7 @@ static void llm_load_vocab(
|
||||||
|
|
||||||
// for now, only BPE models have pre-tokenizers
|
|
||||||
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
|
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
|
vocab.tokenizer_add_space_prefix = false;
|
||||||
|
vocab.tokenizer_clean_spaces = true;
|
||||||
- if (tokenizer_pre.empty()) {
|
- if (tokenizer_pre.empty()) {
|
||||||
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
||||||
- LLAMA_LOG_WARN("%s: \n", __func__);
|
- LLAMA_LOG_WARN("%s: \n", __func__);
|
||||||
@@ -20,13 +20,13 @@ index 61948751..4b72a293 100644
|
|||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "llama3" ||
|
tokenizer_pre == "llama3" ||
|
||||||
@@ -4888,7 +4879,8 @@ static void llm_load_vocab(
|
@@ -5439,7 +5430,8 @@ static void llm_load_vocab(
|
||||||
tokenizer_pre == "poro-chat") {
|
tokenizer_pre == "jais") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
|
||||||
} else {
|
} else {
|
||||||
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||||
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||||
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
}
|
}
|
||||||
} else {
|
} else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
@@ -1,7 +1,7 @@
-diff --git a/llama.cpp b/llama.cpp
+diff --git a/src/llama.cpp b/src/llama.cpp
 index 40d2ec2c..f34eb79a 100644
---- a/llama.cpp
-+++ b/llama.cpp
+--- a/src/llama.cpp
++++ b/src/llama.cpp
 @@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
      struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
      cb(kq, "kq", il);
|
45 llm/patches/07-embeddings.diff Normal file
@@ -0,0 +1,45 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 1fe2b9f7..a43312a7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd = hparams.n_embd;

     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits = cparams.causal_attn;
     const bool has_embd = lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));

     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
         // no output
         res = nullptr;
         embd = nullptr;
-    } else if (cparams.embeddings) {
-        res = nullptr; // do not extract logits for embedding case
-        embd = gf->nodes[gf->n_nodes - 1];
-        if (strcmp(embd->name, "result_embd_pooled") != 0) {
-            embd = gf->nodes[gf->n_nodes - 2];
+    }
+
+    if (cparams.embeddings) {
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            embd = gf->nodes[i];
+            if (strcmp(embd->name, "result_embd_pooled") == 0) {
+                break;
+            }
         }
         GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-    } else {
+    } else {
         embd = nullptr; // do not extract embeddings when not needed
         GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
     }
+
+    if (!cparams.causal_attn) {
+        res = nullptr; // do not extract logits when not needed
+    }
+
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

     ggml_backend_sched_alloc_graph(lctx.sched, gf);
@@ -1,305 +0,0 @@
|
|||||||
From 5cadb45f39d001ffbad95b690d6cf0abcb4a6d96 Mon Sep 17 00:00:00 2001
|
|
||||||
From: Ollama maintainers <hello@ollama.com>
|
|
||||||
Date: Wed, 26 Jun 2024 16:18:09 -0700
|
|
||||||
Subject: [PATCH] Architecture support
|
|
||||||
|
|
||||||
---
|
|
||||||
llama.cpp | 194 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
|
|
||||||
1 file changed, 193 insertions(+), 1 deletion(-)
|
|
||||||
|
|
||||||
diff --git a/llama.cpp b/llama.cpp
|
|
||||||
index 61948751..3b4196f5 100644
|
|
||||||
--- a/llama.cpp
|
|
||||||
+++ b/llama.cpp
|
|
||||||
@@ -217,6 +217,7 @@ enum llm_arch {
|
|
||||||
LLM_ARCH_INTERNLM2,
|
|
||||||
LLM_ARCH_MINICPM,
|
|
||||||
LLM_ARCH_GEMMA,
|
|
||||||
+ LLM_ARCH_GEMMA2,
|
|
||||||
LLM_ARCH_STARCODER2,
|
|
||||||
LLM_ARCH_MAMBA,
|
|
||||||
LLM_ARCH_XVERSE,
|
|
||||||
@@ -255,6 +256,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|
||||||
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
|
||||||
{ LLM_ARCH_MINICPM, "minicpm" },
|
|
||||||
{ LLM_ARCH_GEMMA, "gemma" },
|
|
||||||
+ { LLM_ARCH_GEMMA2, "gemma2" },
|
|
||||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
|
||||||
{ LLM_ARCH_MAMBA, "mamba" },
|
|
||||||
{ LLM_ARCH_XVERSE, "xverse" },
|
|
||||||
@@ -464,10 +466,12 @@ enum llm_tensor {
|
|
||||||
LLM_TENSOR_ATTN_NORM,
|
|
||||||
LLM_TENSOR_ATTN_NORM_2,
|
|
||||||
LLM_TENSOR_ATTN_OUT_NORM,
|
|
||||||
+ LLM_TENSOR_ATTN_POST_NORM,
|
|
||||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
|
||||||
LLM_TENSOR_FFN_GATE_INP,
|
|
||||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
|
||||||
LLM_TENSOR_FFN_NORM,
|
|
||||||
+ LLM_TENSOR_FFN_POST_NORM,
|
|
||||||
LLM_TENSOR_FFN_GATE,
|
|
||||||
LLM_TENSOR_FFN_DOWN,
|
|
||||||
LLM_TENSOR_FFN_UP,
|
|
||||||
@@ -960,6 +964,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
+ {
|
|
||||||
+ LLM_ARCH_GEMMA2,
|
|
||||||
+ {
|
|
||||||
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
|
||||||
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
|
||||||
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
|
||||||
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
|
||||||
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
|
||||||
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
|
||||||
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
|
||||||
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
|
||||||
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
|
||||||
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
|
||||||
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
|
||||||
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
|
||||||
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
|
||||||
+ },
|
|
||||||
+ },
|
|
||||||
{
|
|
||||||
LLM_ARCH_STARCODER2,
|
|
||||||
{
|
|
||||||
@@ -1941,6 +1963,8 @@ enum e_model {
|
|
||||||
MODEL_8x22B,
|
|
||||||
MODEL_16x12B,
|
|
||||||
MODEL_10B_128x3_66B,
|
|
||||||
+ MODEL_9B,
|
|
||||||
+ MODEL_27B,
|
|
||||||
};
|
|
||||||
|
|
||||||
static const size_t kiB = 1024;
|
|
||||||
@@ -2114,6 +2138,7 @@ struct llama_layer {
|
|
||||||
struct ggml_tensor * attn_out_norm_b;
|
|
||||||
struct ggml_tensor * attn_q_a_norm;
|
|
||||||
struct ggml_tensor * attn_kv_a_norm;
|
|
||||||
+ struct ggml_tensor * attn_post_norm;
|
|
||||||
|
|
||||||
// attention
|
|
||||||
struct ggml_tensor * wq;
|
|
||||||
@@ -2136,6 +2161,7 @@ struct llama_layer {
|
|
||||||
// normalization
|
|
||||||
struct ggml_tensor * ffn_norm;
|
|
||||||
struct ggml_tensor * ffn_norm_b;
|
|
||||||
+ struct ggml_tensor * ffn_post_norm;
|
|
||||||
struct ggml_tensor * layer_out_norm;
|
|
||||||
struct ggml_tensor * layer_out_norm_b;
|
|
||||||
struct ggml_tensor * ffn_norm_exps;
|
|
||||||
@@ -4529,6 +4555,16 @@ static void llm_load_hparams(
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
case LLM_ARCH_GEMMA:
|
|
||||||
+ {
|
|
||||||
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
||||||
+
|
|
||||||
+ switch (hparams.n_layer) {
|
|
||||||
+ case 18: model.type = e_model::MODEL_9B; break;
|
|
||||||
+ case 28: model.type = e_model::MODEL_27B; break;
|
|
||||||
+ default: model.type = e_model::MODEL_UNKNOWN;
|
|
||||||
+ }
|
|
||||||
+ } break;
|
|
||||||
+ case LLM_ARCH_GEMMA2:
|
|
||||||
{
|
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
|
||||||
|
|
||||||
@@ -6305,6 +6341,40 @@ static bool llm_load_tensors(
|
|
||||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
|
||||||
}
|
|
||||||
} break;
|
|
||||||
+ case LLM_ARCH_GEMMA2:
|
|
||||||
+ {
|
|
||||||
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
||||||
+
|
|
||||||
+ // output
|
|
||||||
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
|
||||||
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
|
||||||
+
|
|
||||||
+ const int64_t n_ff = hparams.n_ff;
|
|
||||||
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
|
||||||
+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
|
||||||
+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
|
||||||
+
|
|
||||||
+ for (uint32_t i = 0; i < n_layer; ++i) {
|
|
||||||
+ ggml_context * ctx_layer = ctx_for_layer(i);
|
|
||||||
+ ggml_context * ctx_split = ctx_for_layer_split(i);
|
|
||||||
+
|
|
||||||
+ auto & layer = model.layers[i];
|
|
||||||
+
|
|
||||||
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
|
||||||
+
|
|
||||||
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
|
|
||||||
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
|
|
||||||
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
|
|
||||||
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
|
|
||||||
+ layer.attn_post_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
|
|
||||||
+
|
|
||||||
+ layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
|
||||||
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
|
||||||
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
|
||||||
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
|
||||||
+ layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
|
|
||||||
+ }
|
|
||||||
+ } break;
|
|
||||||
case LLM_ARCH_STARCODER2:
|
|
||||||
{
|
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
||||||
@@ -10614,6 +10684,123 @@ struct llm_build_context {
|
|
||||||
return gf;
|
|
||||||
}
|
|
||||||
|
|
||||||
+ struct ggml_cgraph * build_gemma2() {
|
|
||||||
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
|
||||||
+
|
|
||||||
+ const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
|
||||||
+
|
|
||||||
+ struct ggml_tensor * cur;
|
|
||||||
+ struct ggml_tensor * inpL;
|
|
||||||
+
|
|
||||||
+ inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
|
||||||
+
|
|
||||||
+ inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
|
||||||
+ cb(inpL, "inp_scaled", -1);
|
|
||||||
+
|
|
||||||
+ // inp_pos - contains the positions
|
|
||||||
+ struct ggml_tensor * inp_pos = build_inp_pos();
|
|
||||||
+
|
|
||||||
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
|
||||||
+ struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
|
||||||
+
|
|
||||||
+ for (int il = 0; il < n_layer; ++il) {
|
|
||||||
+ // norm
|
|
||||||
+ cur = llm_build_norm(ctx0, inpL, hparams,
|
|
||||||
+ model.layers[il].attn_norm, NULL,
|
|
||||||
+ LLM_NORM_RMS, cb, il);
|
|
||||||
+ cb(cur, "attn_norm", il);
|
|
||||||
+
|
|
||||||
+ // self-attention
|
|
||||||
+ {
|
|
||||||
+ // compute Q and K and RoPE them
|
|
||||||
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
|
||||||
+ cb(Qcur, "Qcur", il);
|
|
||||||
+
|
|
||||||
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
|
||||||
+ cb(Kcur, "Kcur", il);
|
|
||||||
+
|
|
||||||
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
|
||||||
+ cb(Vcur, "Vcur", il);
|
|
||||||
+
|
|
||||||
+ Qcur = ggml_rope_ext(
|
|
||||||
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
|
||||||
+ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
||||||
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
|
||||||
+ cb(Qcur, "Qcur", il);
|
|
||||||
+
|
|
||||||
+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
|
|
||||||
+ cb(Qcur, "Qcur_scaled", il);
|
|
||||||
+
|
|
||||||
+ Kcur = ggml_rope_ext(
|
|
||||||
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
|
||||||
+ n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
||||||
+ ext_factor, attn_factor, beta_fast, beta_slow);
|
|
||||||
+ cb(Kcur, "Kcur", il);
|
|
||||||
+
|
|
||||||
+ cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
|
||||||
+ model.layers[il].wo, NULL,
|
|
||||||
+ Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ if (il == n_layer - 1) {
|
|
||||||
+ // skip computing output for unused tokens
|
|
||||||
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
|
||||||
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
|
||||||
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ cur = llm_build_norm(ctx0, cur, hparams,
|
|
||||||
+ model.layers[il].attn_post_norm, NULL,
|
|
||||||
+ LLM_NORM_RMS, cb, il);
|
|
||||||
+ cb(cur, "attn_post_norm", il);
|
|
||||||
+
|
|
||||||
+ struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
|
||||||
+ cb(sa_out, "sa_out", il);
|
|
||||||
+
|
|
||||||
+ cur = llm_build_norm(ctx0, sa_out, hparams,
|
|
||||||
+ model.layers[il].ffn_norm, NULL,
|
|
||||||
+ LLM_NORM_RMS, cb, il);
|
|
||||||
+ cb(cur, "ffn_norm", il);
|
|
||||||
+
|
|
||||||
+ // feed-forward network
|
|
||||||
+ {
|
|
||||||
+ cur = llm_build_ffn(ctx0, cur,
|
|
||||||
+ model.layers[il].ffn_up, NULL,
|
|
||||||
+ model.layers[il].ffn_gate, NULL,
|
|
||||||
+ model.layers[il].ffn_down, NULL,
|
|
||||||
+ NULL,
|
|
||||||
+ LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
|
||||||
+ cb(cur, "ffn_out", il);
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ cur = llm_build_norm(ctx0, cur, hparams,
|
|
||||||
+ model.layers[il].ffn_post_norm, NULL,
|
|
||||||
+ LLM_NORM_RMS, cb, -1);
|
|
||||||
+ cb(cur, "ffn_post_norm", -1);
|
|
||||||
+
|
|
||||||
+ cur = ggml_add(ctx0, cur, sa_out);
|
|
||||||
+ cb(cur, "l_out", il);
|
|
||||||
+
|
|
||||||
+ // input for next layer
|
|
||||||
+ inpL = cur;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ cur = inpL;
|
|
||||||
+
|
|
||||||
+ cur = llm_build_norm(ctx0, cur, hparams,
|
|
||||||
+ model.output_norm, NULL,
|
|
||||||
+ LLM_NORM_RMS, cb, -1);
|
|
||||||
+ cb(cur, "result_norm", -1);
|
|
||||||
+
|
|
||||||
+ // lm_head
|
|
||||||
+ cur = ggml_mul_mat(ctx0, model.output, cur);
|
|
||||||
+ cb(cur, "result_output", -1);
|
|
||||||
+
|
|
||||||
+ ggml_build_forward_expand(gf, cur);
|
|
||||||
+
|
|
||||||
+ return gf;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
struct ggml_cgraph * build_starcoder2() {
|
|
||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
|
||||||
|
|
||||||
@@ -11847,6 +12034,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|
||||||
{
|
|
||||||
result = llm.build_gemma();
|
|
||||||
} break;
|
|
||||||
+ case LLM_ARCH_GEMMA2:
|
|
||||||
+ {
|
|
||||||
+ result = llm.build_gemma2();
|
|
||||||
+ } break;
|
|
||||||
case LLM_ARCH_STARCODER2:
|
|
||||||
{
|
|
||||||
result = llm.build_starcoder2();
|
|
||||||
@@ -16671,6 +16862,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|
||||||
case LLM_ARCH_PHI2:
|
|
||||||
case LLM_ARCH_PHI3:
|
|
||||||
case LLM_ARCH_GEMMA:
|
|
||||||
+ case LLM_ARCH_GEMMA2:
|
|
||||||
case LLM_ARCH_STARCODER2:
|
|
||||||
case LLM_ARCH_GPTNEOX:
|
|
||||||
return LLAMA_ROPE_TYPE_NEOX;
|
|
||||||
@@ -18551,7 +18743,7 @@ static int32_t llama_chat_apply_template_internal(
|
|
||||||
if (add_ass) {
|
|
||||||
ss << "<s>assistant\n";
|
|
||||||
}
|
|
||||||
- } else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
|
||||||
+ } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
|
||||||
// google/gemma-7b-it
|
|
||||||
std::string system_prompt = "";
|
|
||||||
for (auto message : chat) {
|
|
||||||
--
|
|
||||||
2.45.2
|
|
||||||
|
|
42 llm/patches/08-clip-unicode.diff Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||||
|
index 95fbe3d0..5a02a6ec 100644
|
||||||
|
--- a/examples/llava/clip.cpp
|
||||||
|
+++ b/examples/llava/clip.cpp
|
||||||
|
@@ -32,6 +33,14 @@
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
+#if defined(_WIN32)
|
||||||
|
+#define WIN32_LEAN_AND_MEAN
|
||||||
|
+#ifndef NOMINMAX
|
||||||
|
+ #define NOMINMAX
|
||||||
|
+#endif
|
||||||
|
+#include <windows.h>
|
||||||
|
+#endif
|
||||||
|
+
|
||||||
|
//#define CLIP_DEBUG_FUNCTIONS
|
||||||
|
|
||||||
|
// RGB uint8 image
|
||||||
|
@@ -1055,7 +1064,22 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
+#ifdef _WIN32
|
||||||
|
+ int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
|
||||||
|
+ if (!wlen) {
|
||||||
|
+ return NULL;
|
||||||
|
+ }
|
||||||
|
+ wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
|
||||||
|
+ wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
|
||||||
|
+ if (!wlen) {
|
||||||
|
+ free(wbuf);
|
||||||
|
+ return NULL;
|
||||||
|
+ }
|
||||||
|
+ auto fin = std::ifstream(wbuf, std::ios::binary);
|
||||||
|
+ free(wbuf);
|
||||||
|
+#else
|
||||||
|
auto fin = std::ifstream(fname, std::ios::binary);
|
||||||
|
+#endif
|
||||||
|
if (!fin) {
|
||||||
|
LOG_TEE("cannot open model file for loading tensors\n");
|
||||||
|
clip_free(new_clip);
|
60 llm/patches/09-pooling.diff Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||||
|
index 721b8f4e..cfe7ac40 100644
|
||||||
|
--- a/src/llama.cpp
|
||||||
|
+++ b/src/llama.cpp
|
||||||
|
@@ -8420,14 +8420,14 @@ struct llm_build_context {
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * build_inp_mean() {
|
||||||
|
- lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
||||||
|
+ lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, cparams.n_seq_max);
|
||||||
|
cb(lctx.inp_mean, "inp_mean", -1);
|
||||||
|
ggml_set_input(lctx.inp_mean);
|
||||||
|
return lctx.inp_mean;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * build_inp_cls() {
|
||||||
|
- lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||||
|
+ lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_seq_max);
|
||||||
|
cb(lctx.inp_cls, "inp_cls", -1);
|
||||||
|
ggml_set_input(lctx.inp_cls);
|
||||||
|
return lctx.inp_cls;
|
||||||
|
@@ -13847,19 +13847,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
|
||||||
|
|
||||||
|
float * data = (float *) lctx.inp_mean->data;
|
||||||
|
- memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
|
||||||
|
+ memset(lctx.inp_mean->data, 0, n_tokens * cparams.n_seq_max * ggml_element_size(lctx.inp_mean));
|
||||||
|
|
||||||
|
std::vector<uint64_t> sum(n_tokens, 0);
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
|
-
|
||||||
|
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
||||||
|
-
|
||||||
|
sum[seq_id] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
- std::vector<float> div(n_tokens, 0.0f);
|
||||||
|
- for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
+ std::vector<float> div(cparams.n_seq_max, 0.0f);
|
||||||
|
+ for (uint32_t i = 0; i < cparams.n_seq_max; ++i) {
|
||||||
|
const uint64_t s = sum[i];
|
||||||
|
if (s > 0) {
|
||||||
|
div[i] = 1.0f/float(s);
|
||||||
|
@@ -13879,14 +13876,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
|
||||||
|
|
||||||
|
uint32_t * data = (uint32_t *) lctx.inp_cls->data;
|
||||||
|
- memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
|
||||||
|
+ memset(lctx.inp_cls->data, 0, cparams.n_seq_max * ggml_element_size(lctx.inp_cls));
|
||||||
|
|
||||||
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
|
const llama_seq_id seq_id = batch.seq_id[i][0];
|
||||||
|
const llama_pos pos = batch.pos[i];
|
||||||
|
-
|
||||||
|
- GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
|
||||||
|
-
|
||||||
|
if (pos == 0) {
|
||||||
|
data[seq_id] = i;
|
||||||
|
}
|
@@ -38,7 +38,7 @@ func Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var variants []string
|
var variants []string
|
||||||
for v := range availableServers() {
|
for v := range getAvailableServers() {
|
||||||
variants = append(variants, v)
|
variants = append(variants, v)
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
||||||
@@ -50,7 +50,7 @@ func Init() error {
|
|||||||
// binary names may contain an optional variant separated by '_'
|
// binary names may contain an optional variant separated by '_'
|
||||||
// For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
|
// For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
|
||||||
// Any library without a variant is the lowest common denominator
|
// Any library without a variant is the lowest common denominator
|
||||||
func availableServers() map[string]string {
|
func getAvailableServers() map[string]string {
|
||||||
payloadsDir, err := gpu.PayloadsDir()
|
payloadsDir, err := gpu.PayloadsDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("payload lookup error", "error", err)
|
slog.Error("payload lookup error", "error", err)
|
||||||
@@ -80,7 +80,7 @@ func availableServers() map[string]string {
|
|||||||
// TODO - switch to metadata based mapping
|
// TODO - switch to metadata based mapping
|
||||||
func serversForGpu(info gpu.GpuInfo) []string {
|
func serversForGpu(info gpu.GpuInfo) []string {
|
||||||
// glob workDir for files that start with ollama_
|
// glob workDir for files that start with ollama_
|
||||||
availableServers := availableServers()
|
availableServers := getAvailableServers()
|
||||||
requested := info.Library
|
requested := info.Library
|
||||||
if info.Variant != gpu.CPUCapabilityNone {
|
if info.Variant != gpu.CPUCapabilityNone {
|
||||||
requested += "_" + info.Variant.String()
|
requested += "_" + info.Variant.String()
|
||||||
@@ -115,27 +115,29 @@ func serversForGpu(info gpu.GpuInfo) []string {
 		servers = append(servers, alt...)
 	}

-	// Load up the best CPU variant if not primary requested
-	if info.Library != "cpu" {
-		variant := gpu.GetCPUCapability()
-		// If no variant, then we fall back to default
-		// If we have a variant, try that if we find an exact match
-		// Attempting to run the wrong CPU instructions will panic the
-		// process
-		if variant != gpu.CPUCapabilityNone {
-			for cmp := range availableServers {
-				if cmp == "cpu_"+variant.String() {
-					servers = append(servers, cmp)
-					break
-				}
-			}
-		} else {
-			servers = append(servers, "cpu")
-		}
-	}
-
-	if len(servers) == 0 {
-		servers = []string{"cpu"}
+	if !(runtime.GOOS == "darwin" && runtime.GOARCH == "arm64") {
+		// Load up the best CPU variant if not primary requested
+		if info.Library != "cpu" {
+			variant := gpu.GetCPUCapability()
+			// If no variant, then we fall back to default
+			// If we have a variant, try that if we find an exact match
+			// Attempting to run the wrong CPU instructions will panic the
+			// process
+			if variant != gpu.CPUCapabilityNone {
+				for cmp := range availableServers {
+					if cmp == "cpu_"+variant.String() {
+						servers = append(servers, cmp)
+						break
+					}
+				}
+			} else {
+				servers = append(servers, "cpu")
+			}
+		}
+
+		if len(servers) == 0 {
+			servers = []string{"cpu"}
+		}
 	}

 	return servers
@@ -147,7 +149,7 @@ func serverForCpu() string {
|
|||||||
return "metal"
|
return "metal"
|
||||||
}
|
}
|
||||||
variant := gpu.GetCPUCapability()
|
variant := gpu.GetCPUCapability()
|
||||||
availableServers := availableServers()
|
availableServers := getAvailableServers()
|
||||||
if variant != gpu.CPUCapabilityNone {
|
if variant != gpu.CPUCapabilityNone {
|
||||||
for cmp := range availableServers {
|
for cmp := range availableServers {
|
||||||
if cmp == "cpu_"+variant.String() {
|
if cmp == "cpu_"+variant.String() {
|
||||||
|
@@ -33,7 +33,7 @@ type LlamaServer interface {
|
|||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
Embed(ctx context.Context, input []string) ([][]float32, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -88,6 +88,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
var estimate MemoryEstimate
|
var estimate MemoryEstimate
|
||||||
var systemTotalMemory uint64
|
var systemTotalMemory uint64
|
||||||
var systemFreeMemory uint64
|
var systemFreeMemory uint64
|
||||||
|
var systemSwapFreeMemory uint64
|
||||||
|
|
||||||
systemMemInfo, err := gpu.GetCPUMem()
|
systemMemInfo, err := gpu.GetCPUMem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,7 +96,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
} else {
|
} else {
|
||||||
systemTotalMemory = systemMemInfo.TotalMemory
|
systemTotalMemory = systemMemInfo.TotalMemory
|
||||||
systemFreeMemory = systemMemInfo.FreeMemory
|
systemFreeMemory = systemMemInfo.FreeMemory
|
||||||
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
|
systemSwapFreeMemory = systemMemInfo.FreeSwap
|
||||||
|
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
||||||
@@ -122,6 +124,16 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// On linux, over-allocating CPU memory will almost always result in an error
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
|
||||||
|
available := systemFreeMemory + systemSwapFreeMemory
|
||||||
|
if systemMemoryRequired > available {
|
||||||
|
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
|
||||||
|
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
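The Linux guard added above can be read in isolation; a minimal sketch, assuming the estimate and memory figures are passed in directly (a hypothetical helper, not the actual server wiring; format.HumanBytes2 is the helper already used in the diff):

package llm

import (
	"fmt"
	"runtime"

	"github.com/ollama/ollama/format"
)

// checkSystemMemory mirrors the guard above: on Linux, refuse to load a model
// whose CPU-resident portion exceeds free RAM plus free swap.
func checkSystemMemory(totalSize, vramSize, freeMemory, freeSwap uint64) error {
	if runtime.GOOS != "linux" {
		return nil // the hard failure above only applies on Linux
	}
	required := totalSize - vramSize // bytes that must stay in system memory
	available := freeMemory + freeSwap
	if required > available {
		return fmt.Errorf("model requires more system memory (%s) than is available (%s)",
			format.HumanBytes2(required), format.HumanBytes2(available))
	}
	return nil
}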
estimate.log()
|
estimate.log()
|
||||||
|
|
||||||
// Loop through potential servers
|
// Loop through potential servers
|
||||||
@@ -131,7 +143,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
availableServers := availableServers()
|
availableServers := getAvailableServers()
|
||||||
|
if len(availableServers) == 0 {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
slog.Warn("llama server binary disappeared, reinitializing payloads")
|
||||||
|
err = Init()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to reinitialize payloads", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
availableServers = getAvailableServers()
|
||||||
|
} else {
|
||||||
|
return nil, finalErr
|
||||||
|
}
|
||||||
|
}
|
||||||
var servers []string
|
var servers []string
|
||||||
if cpuRunner != "" {
|
if cpuRunner != "" {
|
||||||
servers = []string{cpuRunner}
|
servers = []string{cpuRunner}
|
||||||
@@ -208,7 +233,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
if g.Library == "metal" &&
|
if g.Library == "metal" &&
|
||||||
uint64(opts.NumGPU) > 0 &&
|
uint64(opts.NumGPU) > 0 &&
|
||||||
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||||
opts.UseMMap = api.TriStateFalse
|
opts.UseMMap = new(bool)
|
||||||
|
*opts.UseMMap = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,10 +245,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
// Windows CUDA should not use mmap for best performance
|
// Windows CUDA should not use mmap for best performance
|
||||||
// Linux with a model larger than free space, mmap leads to thrashing
|
// Linux with a model larger than free space, mmap leads to thrashing
|
||||||
// For CPU loads we want the memory to be allocated, not FS cache
|
// For CPU loads we want the memory to be allocated, not FS cache
|
||||||
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == api.TriStateUndefined) ||
|
if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) ||
|
||||||
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == api.TriStateUndefined) ||
|
(runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) ||
|
||||||
(gpus[0].Library == "cpu" && opts.UseMMap == api.TriStateUndefined) ||
|
(gpus[0].Library == "cpu" && opts.UseMMap == nil) ||
|
||||||
opts.UseMMap == api.TriStateFalse {
|
(opts.UseMMap != nil && !*opts.UseMMap) {
|
||||||
params = append(params, "--no-mmap")
|
params = append(params, "--no-mmap")
|
||||||
}
|
}
|
||||||
|
|
||||||
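UseMMap is now a *bool rather than an api.TriState value: nil means the user did not set use_mmap, while a pointer to true or false records an explicit choice. A simplified sketch of how that tri-state can feed the --no-mmap decision (the real condition above also checks OS, GPU library, and free memory):

// useMMapParams illustrates the *bool tri-state: nil = unset,
// non-nil pointer = explicit user choice.
func useMMapParams(useMMap *bool, preferNoMMap bool) []string {
	explicitlyOff := useMMap != nil && !*useMMap
	unsetButDiscouraged := useMMap == nil && preferNoMMap
	if explicitlyOff || unsetButDiscouraged {
		return []string{"--no-mmap"}
	}
	return nil
}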
@@ -240,10 +266,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
|||||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||||
}
|
}
|
||||||
|
|
||||||
if estimate.TensorSplit != "" {
|
|
||||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range len(servers) {
|
for i := range len(servers) {
|
||||||
dir := availableServers[servers[i]]
|
dir := availableServers[servers[i]]
|
||||||
if dir == "" {
|
if dir == "" {
|
||||||
@@ -560,6 +582,9 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
|||||||
if s.status != nil && s.status.LastErrMsg != "" {
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
msg = s.status.LastErrMsg
|
msg = s.status.LastErrMsg
|
||||||
}
|
}
|
||||||
|
if strings.Contains(msg, "unknown model") {
|
||||||
|
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
|
||||||
|
}
|
||||||
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -662,7 +687,7 @@ type CompletionRequest struct {
|
|||||||
Prompt string
|
Prompt string
|
||||||
Format string
|
Format string
|
||||||
Images []ImageData
|
Images []ImageData
|
||||||
Options api.Options
|
Options *api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompletionResponse struct {
|
type CompletionResponse struct {
|
||||||
@@ -682,10 +707,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
}
|
}
|
||||||
defer s.sem.Release(1)
|
defer s.sem.Release(1)
|
||||||
|
|
||||||
// only allow maximum 10 "context shifts" to avoid infinite generation
|
// put an upper limit on num_predict to avoid the model running on forever
|
||||||
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
||||||
req.Options.NumPredict = 10 * s.options.NumCtx
|
req.Options.NumPredict = 10 * s.options.NumCtx
|
||||||
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
request := map[string]any{
|
request := map[string]any{
|
||||||
@@ -843,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbedRequest struct {
|
||||||
Content string `json:"content"`
|
Content []string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbedResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding [][]float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -866,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
data, err := json.Marshal(EmbedRequest{Content: input})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -893,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||||||
return nil, fmt.Errorf("%s", body)
|
return nil, fmt.Errorf("%s", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var embedding EmbeddingResponse
|
var embedding EmbedResponse
|
||||||
if err := json.Unmarshal(body, &embedding); err != nil {
|
if err := json.Unmarshal(body, &embedding); err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||||
}
|
}
|
||||||
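The runner's embedding call is now batched: EmbedRequest carries a slice of inputs and EmbedResponse a slice of vectors. A sketch of that round trip using the types from the diff; doPost is a made-up stand-in for the HTTP call, and json/fmt come from the surrounding file's imports:

// embedBatch shows the batched shapes introduced above; doPost is hypothetical.
func embedBatch(doPost func([]byte) ([]byte, error), input []string) ([][]float32, error) {
	data, err := json.Marshal(EmbedRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}
	body, err := doPost(data)
	if err != nil {
		return nil, err
	}
	var resp EmbedResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, fmt.Errorf("unmarshal embed response: %w", err)
	}
	return resp.Embedding, nil // one []float32 per input string
}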
|
@@ -25,6 +25,7 @@ var errorPrefixes = []string{
|
|||||||
"CUDA error",
|
"CUDA error",
|
||||||
"cudaMalloc failed",
|
"cudaMalloc failed",
|
||||||
"\"ERR\"",
|
"\"ERR\"",
|
||||||
|
"error loading model",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||||
|
573 openai/openai.go
@@ -3,15 +3,18 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
@@ -27,7 +30,7 @@ type ErrorResponse struct {
|
|||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content any `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Choice struct {
|
type Choice struct {
|
||||||
@@ -42,6 +45,12 @@ type ChunkChoice struct {
|
|||||||
FinishReason *string `json:"finish_reason"`
|
FinishReason *string `json:"finish_reason"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CompleteChunkChoice struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason *string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
type Usage struct {
|
type Usage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
@@ -52,6 +61,11 @@ type ResponseFormat struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmbedRequest struct {
|
||||||
|
Input any `json:"input"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
type ChatCompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
@@ -85,6 +99,63 @@ type ChatCompletionChunk struct {
|
|||||||
Choices []ChunkChoice `json:"choices"`
|
Choices []ChunkChoice `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
||||||
|
MaxTokens *int `json:"max_tokens"`
|
||||||
|
PresencePenalty float32 `json:"presence_penalty"`
|
||||||
|
Seed *int `json:"seed"`
|
||||||
|
Stop any `json:"stop"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
Temperature *float32 `json:"temperature"`
|
||||||
|
TopP float32 `json:"top_p"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Completion struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Choices []CompleteChunkChoice `json:"choices"`
|
||||||
|
Usage Usage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CompletionChunk struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Choices []CompleteChunkChoice `json:"choices"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Embedding struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float32 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListCompletion struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Model `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddingList struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []Embedding `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
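The new types above back the OpenAI-compatible legacy completions and embeddings endpoints. A small illustrative client payload for CompletionRequest; the model name and values are made up, and only the JSON field names come from the struct tags above:

// exampleCompletionBody builds a sample legacy-completions request body.
func exampleCompletionBody() ([]byte, error) {
	maxTokens := 64
	seed := 42
	temperature := float32(0.7)
	return json.Marshal(CompletionRequest{
		Model:       "llama3", // example model name
		Prompt:      "Why is the sky blue?",
		MaxTokens:   &maxTokens,
		Seed:        &seed,
		Temperature: &temperature,
		Stream:      false,
	})
}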
|
||||||
func NewError(code int, message string) ErrorResponse {
|
func NewError(code int, message string) ErrorResponse {
|
||||||
var etype string
|
var etype string
|
||||||
switch code {
|
switch code {
|
||||||
@@ -145,10 +216,159 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||||
|
+	return Completion{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+	return CompletionChunk{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+	}
+}
+
+func toListCompletion(r api.ListResponse) ListCompletion {
+	var data []Model
+	for _, m := range r.Models {
+		data = append(data, Model{
+			Id:      m.Name,
+			Object:  "model",
+			Created: m.ModifiedAt.Unix(),
+			OwnedBy: model.ParseName(m.Name).Namespace,
+		})
+	}
+
+	return ListCompletion{
+		Object: "list",
+		Data:   data,
+	}
+}
+
+func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
+	if r.Embeddings != nil {
+		var data []Embedding
+		for i, e := range r.Embeddings {
+			data = append(data, Embedding{
+				Object:    "embedding",
+				Embedding: e,
+				Index:     i,
+			})
+		}
+
+		return EmbeddingList{
+			Object: "list",
+			Data:   data,
+			Model:  model,
+		}
+	}
+
+	return EmbeddingList{}
+}
+
+func toModel(r api.ShowResponse, m string) Model {
+	return Model{
+		Id:      m,
+		Object:  "model",
+		Created: r.ModifiedAt.Unix(),
+		OwnedBy: model.ParseName(m).Namespace,
+	}
+}
+
+func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
-		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
+		switch content := msg.Content.(type) {
+		case string:
+			messages = append(messages, api.Message{Role: msg.Role, Content: content})
+		case []any:
+			message := api.Message{Role: msg.Role}
+			for _, c := range content {
+				data, ok := c.(map[string]any)
+				if !ok {
+					return nil, fmt.Errorf("invalid message format")
+				}
+				switch data["type"] {
+				case "text":
+					text, ok := data["text"].(string)
+					if !ok {
+						return nil, fmt.Errorf("invalid message format")
+					}
+					message.Content = text
+				case "image_url":
+					var url string
+					if urlMap, ok := data["image_url"].(map[string]any); ok {
+						if url, ok = urlMap["url"].(string); !ok {
+							return nil, fmt.Errorf("invalid message format")
+						}
+					} else {
+						if url, ok = data["image_url"].(string); !ok {
+							return nil, fmt.Errorf("invalid message format")
+						}
+					}
+
+					types := []string{"jpeg", "jpg", "png"}
+					valid := false
+					for _, t := range types {
+						prefix := "data:image/" + t + ";base64,"
+						if strings.HasPrefix(url, prefix) {
+							url = strings.TrimPrefix(url, prefix)
+							valid = true
+							break
+						}
+					}
+
+					if !valid {
+						return nil, fmt.Errorf("invalid image input")
+					}
+
+					img, err := base64.StdEncoding.DecodeString(url)
+					if err != nil {
+						return nil, fmt.Errorf("invalid message format")
+					}
+					message.Images = append(message.Images, img)
+				default:
+					return nil, fmt.Errorf("invalid message format")
+				}
+			}
+			messages = append(messages, message)
+		default:
+			return nil, fmt.Errorf("invalid message content type: %T", content)
+		}
	}

	options := make(map[string]interface{})
@@ -156,7 +376,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
-	case []interface{}:
+	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
@@ -199,22 +419,96 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
		format = "json"
	}

-	return api.ChatRequest{
+	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
-	}
+	}, nil
}

-type writer struct {
-	stream bool
-	id     string
+func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+	options := make(map[string]any)
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []any:
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			} else {
+				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
+			}
+		}
+		options["stop"] = stops
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+	}
+
+	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
+
+	options["presence_penalty"] = r.PresencePenalty * 2.0
+
+	if r.TopP != 0.0 {
+		options["top_p"] = r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	return api.GenerateRequest{
+		Model:   r.Model,
+		Prompt:  r.Prompt,
+		Options: options,
+		Stream:  &r.Stream,
+	}, nil
+}
+
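The scaling above is easiest to see with concrete numbers. The following is a small illustrative test, not part of this change set; it assumes it would sit in the openai package next to fromCompleteRequest (with the standard testing import), and its expected values mirror the "completions handler" case in openai_test.go further down: an OpenAI-style temperature of 0.8 becomes 1.6 in the ollama options, and an unset top_p falls back to 1.0.

// Illustrative sketch only, not part of this diff; assumes the openai package context.
func TestFromCompleteRequestMapping(t *testing.T) {
	temp := float32(0.8)
	genReq, err := fromCompleteRequest(CompletionRequest{
		Model:       "test-model",
		Prompt:      "Hello",
		Temperature: &temp,
	})
	if err != nil {
		t.Fatal(err)
	}

	// the OpenAI-style temperature 0.8 is doubled to 1.6; an unset top_p defaults to 1.0
	if genReq.Options["temperature"] != float32(1.6) {
		t.Fatalf("expected 1.6, got %v", genReq.Options["temperature"])
	}
	if genReq.Options["top_p"] != 1.0 {
		t.Fatalf("expected 1.0, got %v", genReq.Options["top_p"])
	}
}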
+type BaseWriter struct {
	gin.ResponseWriter
}

-func (w *writer) writeError(code int, data []byte) (int, error) {
+type ChatWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type CompleteWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type ListWriter struct {
+	BaseWriter
+}
+
+type RetrieveWriter struct {
+	BaseWriter
+	model string
+}
+
+type EmbedWriter struct {
+	BaseWriter
+	model string
+}
+
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
	var serr api.StatusError
	err := json.Unmarshal(data, &serr)
	if err != nil {
@@ -230,7 +524,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
	return len(data), nil
}

-func (w *writer) writeResponse(data []byte) (int, error) {
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
	var chatResponse api.ChatResponse
	err := json.Unmarshal(data, &chatResponse)
	if err != nil {
@@ -270,7 +564,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
	return len(data), nil
}

-func (w *writer) Write(data []byte) (int, error) {
+func (w *ChatWriter) Write(data []byte) (int, error) {
	code := w.ResponseWriter.Status()
	if code != http.StatusOK {
		return w.writeError(code, data)
@@ -279,7 +573,244 @@ func (w *writer) Write(data []byte) (int, error) {
	return w.writeResponse(data)
}

-func Middleware() gin.HandlerFunc {
+func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// completion chunk
+	if w.stream {
+		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if generateResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *CompleteWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
+	var listResponse api.ListResponse
+	err := json.Unmarshal(data, &listResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *ListWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
+	var showResponse api.ShowResponse
+	err := json.Unmarshal(data, &showResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// retrieve completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
+	var embedResponse api.EmbedResponse
+	err := json.Unmarshal(data, &embedResponse)
+
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
+
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *EmbedWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ListMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		w := &ListWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func RetrieveMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		// response writer
+		w := &RetrieveWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      c.Param("model"),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func CompletionsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req CompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq, err := fromCompleteRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &CompleteWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+		c.Next()
+	}
+}
+
+func EmbeddingsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req EmbedRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if req.Input == "" {
+			req.Input = []string{""}
+		}
+
+		if req.Input == nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		if v, ok := req.Input.([]any); ok && len(v) == 0 {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &EmbedWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      req.Model,
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func ChatMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req ChatCompletionRequest
		err := c.ShouldBindJSON(&req)
@@ -294,17 +825,23 @@ func Middleware() gin.HandlerFunc {
		}

		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+
+		chatReq, err := fromChatRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+		}
+
+		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
			return
		}

		c.Request.Body = io.NopCloser(&b)

-		w := &writer{
-			ResponseWriter: c.Writer,
+		w := &ChatWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
			stream: req.Stream,
			id:     fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
		}

		c.Writer = w
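For context, here is a hedged sketch of how these exported middlewares could be mounted on a gin engine. It is not part of this diff: the import path, the stub handler, and the listen address are assumptions for illustration, and the route paths simply mirror the ones exercised by openai_test.go below; the real wiring lives elsewhere in the server.

// Hedged wiring sketch, not part of this change set.
package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/openai"
)

func main() {
	r := gin.Default()

	// stand-in for the native ollama endpoint handlers
	stub := func(c *gin.Context) { c.Status(http.StatusOK) }

	// each middleware rewrites the OpenAI-style request body into the native
	// ollama request before the handler runs, then translates the response
	r.POST("/api/chat", openai.ChatMiddleware(), stub)
	r.POST("/api/generate", openai.CompletionsMiddleware(), stub)
	r.POST("/api/embed", openai.EmbeddingsMiddleware(), stub)
	r.GET("/api/tags", openai.ListMiddleware(), stub)
	r.GET("/api/show/:model", openai.RetrieveMiddleware(), stub)

	_ = r.Run(":8080")
}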
openai/openai_test.go (new file, 392 lines)
@@ -0,0 +1,392 @@
package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/stretchr/testify/assert"
)

const prefix = `data:image/jpeg;base64,`
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image

func TestMiddlewareRequests(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		Handler  func() gin.HandlerFunc
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, req *http.Request)
	}

	var capturedRequest *http.Request

	captureRequestMiddleware := func() gin.HandlerFunc {
		return func(c *gin.Context) {
			bodyBytes, _ := io.ReadAll(c.Request.Body)
			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
			capturedRequest = c.Request
			c.Next()
		}
	}

	testCases := []testCase{
		{
			Name:    "chat handler",
			Method:  http.MethodPost,
			Path:    "/api/chat",
			Handler: ChatMiddleware,
			Setup: func(t *testing.T, req *http.Request) {
				body := ChatCompletionRequest{
					Model:    "test-model",
					Messages: []Message{{Role: "user", Content: "Hello"}},
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, req *http.Request) {
				var chatReq api.ChatRequest
				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
					t.Fatal(err)
				}

				if chatReq.Messages[0].Role != "user" {
					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
				}

				if chatReq.Messages[0].Content != "Hello" {
					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
				}
			},
		},
		{
			Name:    "completions handler",
			Method:  http.MethodPost,
			Path:    "/api/generate",
			Handler: CompletionsMiddleware,
			Setup: func(t *testing.T, req *http.Request) {
				temp := float32(0.8)
				body := CompletionRequest{
					Model:       "test-model",
					Prompt:      "Hello",
					Temperature: &temp,
					Stop:        []string{"\n", "stop"},
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, req *http.Request) {
				var genReq api.GenerateRequest
				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
					t.Fatal(err)
				}

				if genReq.Prompt != "Hello" {
					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
				}

				if genReq.Options["temperature"] != 1.6 {
					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
				}

				stopTokens, ok := genReq.Options["stop"].([]any)

				if !ok {
					t.Fatalf("expected stop tokens to be a list")
				}

				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
				}
			},
		},
		{
			Name:    "chat handler with image content",
			Method:  http.MethodPost,
			Path:    "/api/chat",
			Handler: ChatMiddleware,
			Setup: func(t *testing.T, req *http.Request) {
				body := ChatCompletionRequest{
					Model: "test-model",
					Messages: []Message{
						{
							Role: "user", Content: []map[string]any{
								{"type": "text", "text": "Hello"},
								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
							},
						},
					},
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, req *http.Request) {
				var chatReq api.ChatRequest
				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
					t.Fatal(err)
				}

				if chatReq.Messages[0].Role != "user" {
					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
				}

				if chatReq.Messages[0].Content != "Hello" {
					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
				}

				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])

				if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
					t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
				}
			},
		},
		{
			Name:    "embed handler single input",
			Method:  http.MethodPost,
			Path:    "/api/embed",
			Handler: EmbeddingsMiddleware,
			Setup: func(t *testing.T, req *http.Request) {
				body := EmbedRequest{
					Input: "Hello",
					Model: "test-model",
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, req *http.Request) {
				var embedReq api.EmbedRequest
				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
					t.Fatal(err)
				}

				if embedReq.Input != "Hello" {
					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
				}

				if embedReq.Model != "test-model" {
					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
				}
			},
		},
		{
			Name:    "embed handler batch input",
			Method:  http.MethodPost,
			Path:    "/api/embed",
			Handler: EmbeddingsMiddleware,
			Setup: func(t *testing.T, req *http.Request) {
				body := EmbedRequest{
					Input: []string{"Hello", "World"},
					Model: "test-model",
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, req *http.Request) {
				var embedReq api.EmbedRequest
				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
					t.Fatal(err)
				}

				input, ok := embedReq.Input.([]any)

				if !ok {
					t.Fatalf("expected input to be a list")
				}

				if input[0].(string) != "Hello" {
					t.Fatalf("expected 'Hello', got %s", input[0])
				}

				if input[1].(string) != "World" {
					t.Fatalf("expected 'World', got %s", input[1])
				}

				if embedReq.Model != "test-model" {
					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
				}
			},
		},
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()

	endpoint := func(c *gin.Context) {
		c.Status(http.StatusOK)
	}

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			router = gin.New()
			router.Use(captureRequestMiddleware())
			router.Use(tc.Handler())
			router.Handle(tc.Method, tc.Path, endpoint)
			req, _ := http.NewRequest(tc.Method, tc.Path, nil)

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			tc.Expected(t, capturedRequest)
		})
	}
}

func TestMiddlewareResponses(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		TestPath string
		Handler  func() gin.HandlerFunc
		Endpoint func(c *gin.Context)
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
	}

	testCases := []testCase{
		{
			Name:     "completions handler error forwarding",
			Method:   http.MethodPost,
			Path:     "/api/generate",
			TestPath: "/api/generate",
			Handler:  CompletionsMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
			},
			Setup: func(t *testing.T, req *http.Request) {
				body := CompletionRequest{
					Model:  "test-model",
					Prompt: "Hello",
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				if resp.Code != http.StatusBadRequest {
					t.Fatalf("expected 400, got %d", resp.Code)
				}

				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
					t.Fatalf("error was not forwarded")
				}
			},
		},
		{
			Name:     "list handler",
			Method:   http.MethodGet,
			Path:     "/api/tags",
			TestPath: "/api/tags",
			Handler:  ListMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
							Name: "Test Model",
						},
					},
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				assert.Equal(t, http.StatusOK, resp.Code)

				var listResp ListCompletion
				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
					t.Fatal(err)
				}

				if listResp.Object != "list" {
					t.Fatalf("expected list, got %s", listResp.Object)
				}

				if len(listResp.Data) != 1 {
					t.Fatalf("expected 1, got %d", len(listResp.Data))
				}

				if listResp.Data[0].Id != "Test Model" {
					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
				}
			},
		},
		{
			Name:     "retrieve model",
			Method:   http.MethodGet,
			Path:     "/api/show/:model",
			TestPath: "/api/show/test-model",
			Handler:  RetrieveMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ShowResponse{
					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				var retrieveResp Model
				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
					t.Fatal(err)
				}

				if retrieveResp.Object != "model" {
					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
				}

				if retrieveResp.Id != "test-model" {
					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
				}
			},
		},
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			router = gin.New()
			router.Use(tc.Handler())
			router.Handle(tc.Method, tc.Path, tc.Endpoint)
			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			tc.Expected(t, resp)
		})
	}
}
@@ -124,7 +124,7 @@ func ParseFile(r io.Reader) (*File, error) {
		case stateComment, stateNil:
			// pass
		case stateValue:
-			s, ok := unquote(b.String())
+			s, ok := unquote(strings.TrimSpace(b.String()))
			if !ok || isSpace(r) {
				if _, err := b.WriteRune(r); err != nil {
					return nil, err
@@ -158,7 +158,7 @@ func ParseFile(r io.Reader) (*File, error) {
	case stateComment, stateNil:
		// pass; nothing to flush
	case stateValue:
-		s, ok := unquote(b.String())
+		s, ok := unquote(strings.TrimSpace(b.String()))
		if !ok {
			return nil, io.ErrUnexpectedEOF
		}
@@ -22,7 +22,13 @@ ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
-TEMPLATE template1
+TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|>"""
`

	reader := strings.NewReader(input)
@@ -36,7 +42,40 @@ TEMPLATE template1
		{Name: "license", Args: "MIT"},
		{Name: "param1", Args: "value1"},
		{Name: "param2", Args: "value2"},
-		{Name: "template", Args: "template1"},
+		{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
+	}
+
+	assert.Equal(t, expectedCommands, modelfile.Commands)
+}
+
+func TestParseFileTrimSpace(t *testing.T) {
+	input := `
+FROM " model 1"
+ADAPTER adapter3
+LICENSE "MIT "
+PARAMETER param1 value1
+PARAMETER param2 value2
+TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ .Response }}<|eot_id|> """
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	expectedCommands := []Command{
+		{Name: "model", Args: " model 1"},
+		{Name: "adapter", Args: "adapter3"},
+		{Name: "license", Args: "MIT "},
+		{Name: "param1", Args: "value1"},
+		{Name: "param2", Args: "value2"},
+		{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
	}

	assert.Equal(t, expectedCommands, modelfile.Commands)
@@ -48,6 +87,26 @@ func TestParseFileFrom(t *testing.T) {
		expected []Command
		err      error
	}{
+		{
+			"FROM \"FOO BAR \"",
+			[]Command{{Name: "model", Args: "FOO BAR "}},
+			nil,
+		},
+		{
+			"FROM \"FOO BAR\"\nPARAMETER param1 value1",
+			[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
+			nil,
+		},
+		{
+			"FROM FOOO BAR ",
+			[]Command{{Name: "model", Args: "FOOO BAR"}},
+			nil,
+		},
+		{
+			"FROM /what/is/the path ",
+			[]Command{{Name: "model", Args: "/what/is/the path"}},
+			nil,
+		},
		{
			"FROM foo",
			[]Command{{Name: "model", Args: "foo"}},
@@ -86,6 +145,11 @@ func TestParseFileFrom(t *testing.T) {
			[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
			nil,
		},
+		{
+			"PARAMETER what the \nFROM lemons make lemonade ",
+			[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
+			nil,
+		},
	}

	for _, c := range cases {
@@ -399,7 +463,7 @@ func TestParseFileParameters(t *testing.T) {
		"mirostat_eta 1.0":           {"mirostat_eta", "1.0"},
		"penalize_newline true":      {"penalize_newline", "true"},
		"stop ### User:":             {"stop", "### User:"},
-		"stop ### User: ":            {"stop", "### User: "},
+		"stop ### User: ":            {"stop", "### User:"},
		"stop \"### User:\"":         {"stop", "### User:"},
		"stop \"### User: \"":        {"stop", "### User: "},
		"stop \"\"\"### User:\"\"\"": {"stop", "### User:"},
@@ -107,9 +107,12 @@ function gatherDependencies() {

    # TODO - this varies based on host build system and MSVC version - drive from dumpbin output
    # currently works for Win11 + MSVC 2019 + Cuda V11
-    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
+    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\"
    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
    cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
+    foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
+        cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
+    }
+
    cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
@@ -6,10 +6,21 @@ set -ex
MACHINE=$(uname -m)

if grep -i "centos" /etc/system-release >/dev/null; then
+    # As of 7/1/2024 mirrorlist.centos.org has been taken offline, so adjust accordingly
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
+
    # Centos 7 derivatives have too old of a git version to run our generate script
    # uninstall and ignore failures
    yum remove -y git
    yum -y install epel-release centos-release-scl
+
+    # The release packages reinstate the mirrors, undo that again
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
+
    yum -y install dnf
    if [ "${MACHINE}" = "x86_64" ]; then
        yum -y install https://repo.ius.io/ius-release-el7.rpm
server/images.go (104 changed lines)
@@ -28,11 +28,27 @@ import (
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
)

+var (
+	errCapabilities         = errors.New("does not support")
+	errCapabilityCompletion = errors.New("completion")
+	errCapabilityTools      = errors.New("tools")
+	errCapabilityInsert     = errors.New("insert")
+)
+
+type Capability string
+
+const (
+	CapabilityCompletion = Capability("completion")
+	CapabilityTools      = Capability("tools")
+	CapabilityInsert     = Capability("insert")
+)
+
type registryOptions struct {
	Insecure bool
	Username string
@@ -48,16 +64,59 @@ type Model struct {
	ParentModel    string
	AdapterPaths   []string
	ProjectorPaths []string
-	Template       string
	System         string
	License        []string
	Digest         string
	Options        map[string]interface{}
	Messages       []Message
+
+	Template *template.Template
}

-func (m *Model) IsEmbedding() bool {
-	return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
+// CheckCapabilities checks if the model has the specified capabilities returning an error describing
+// any missing or unknown capabilities
+func (m *Model) CheckCapabilities(caps ...Capability) error {
+	var errs []error
+	for _, cap := range caps {
+		switch cap {
+		case CapabilityCompletion:
+			f, err := os.Open(m.ModelPath)
+			if err != nil {
+				slog.Error("couldn't open model file", "error", err)
+				continue
+			}
+			defer f.Close()
+
+			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
+			ggml, _, err := llm.DecodeGGML(f, 0)
+			if err != nil {
+				slog.Error("couldn't decode ggml", "error", err)
+				continue
+			}
+
+			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
+				errs = append(errs, errCapabilityCompletion)
+			}
+		case CapabilityTools:
+			if !slices.Contains(m.Template.Vars(), "tools") {
+				errs = append(errs, errCapabilityTools)
+			}
+		case CapabilityInsert:
+			vars := m.Template.Vars()
+			if !slices.Contains(vars, "suffix") {
+				errs = append(errs, errCapabilityInsert)
+			}
+		default:
+			slog.Error("unknown capability", "capability", cap)
+			return fmt.Errorf("unknown capability: %s", cap)
+		}
+	}
+
+	if err := errors.Join(errs...); err != nil {
+		return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
+	}
+
+	return nil
}

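A brief usage sketch follows; it is not part of this change, and the helper name and call site are hypothetical. It shows how the wrapped errCapabilities sentinel lets callers distinguish a missing capability from other failures via errors.Is.

// Hedged sketch, assuming it lives in the same server package as CheckCapabilities.
func ensureToolCapable(m *Model) error {
	if err := m.CheckCapabilities(CapabilityCompletion, CapabilityTools); err != nil {
		if errors.Is(err, errCapabilities) {
			// yields something like "foo:latest: does not support tools"
			return fmt.Errorf("%s: %w", m.ShortName, err)
		}
		return err
	}
	return nil
}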
func (m *Model) String() string {
@@ -82,10 +141,10 @@ func (m *Model) String() string {
		})
	}

-	if m.Template != "" {
+	if m.Template != nil {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "template",
-			Args: m.Template,
+			Args: m.Template.String(),
		})
	}

@@ -135,13 +194,6 @@ type Message struct {
	Content string `json:"content"`
}

-type ManifestV2 struct {
-	SchemaVersion int      `json:"schemaVersion"`
-	MediaType     string   `json:"mediaType"`
-	Config        *Layer   `json:"config"`
-	Layers        []*Layer `json:"layers"`
-}
-
type ConfigV2 struct {
	ModelFormat string `json:"model_format"`
	ModelFamily string `json:"model_family"`
@@ -160,7 +212,7 @@ type RootFS struct {
	DiffIDs []string `json:"diff_ids"`
}

-func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
+func GetManifest(mp ModelPath) (*Manifest, string, error) {
	fp, err := mp.GetManifestPath()
	if err != nil {
		return nil, "", err
@@ -170,7 +222,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
		return nil, "", err
	}

-	var manifest *ManifestV2
+	var manifest *Manifest

	bts, err := os.ReadFile(fp)
	if err != nil {
@@ -198,8 +250,7 @@ func GetModel(name string) (*Model, error) {
		Name:      mp.GetFullTagname(),
		ShortName: mp.GetShortTagname(),
		Digest:    digest,
-		Template:  "{{ .Prompt }}",
-		License:   []string{},
+		Template:  template.DefaultTemplate,
	}

	filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -235,13 +286,17 @@ func GetModel(name string) (*Model, error) {
		model.AdapterPaths = append(model.AdapterPaths, filename)
	case "application/vnd.ollama.image.projector":
		model.ProjectorPaths = append(model.ProjectorPaths, filename)
-	case "application/vnd.ollama.image.template":
+	case "application/vnd.ollama.image.prompt",
+		"application/vnd.ollama.image.template":
		bts, err := os.ReadFile(filename)
		if err != nil {
			return nil, err
		}

-		model.Template = string(bts)
+		model.Template, err = template.Parse(string(bts))
+		if err != nil {
+			return nil, err
+		}
	case "application/vnd.ollama.image.system":
		bts, err := os.ReadFile(filename)
		if err != nil {
@@ -249,13 +304,6 @@ func GetModel(name string) (*Model, error) {
		}

		model.System = string(bts)
-	case "application/vnd.ollama.image.prompt":
-		bts, err := os.ReadFile(filename)
-		if err != nil {
-			return nil, err
-		}
-
-		model.Template = string(bts)
	case "application/vnd.ollama.image.params":
		params, err := os.Open(filename)
		if err != nil {
@@ -822,7 +870,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
	mp := ParseModelPath(name)

-	var manifest *ManifestV2
+	var manifest *Manifest
	var err error
	var noprune string

@@ -929,7 +977,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
	return nil
}

-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

	headers := make(http.Header)
@@ -940,7 +988,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
	}
	defer resp.Body.Close()

-	var m *ManifestV2
+	var m *Manifest
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}
@@ -14,7 +14,10 @@ import (
)

type Manifest struct {
-	ManifestV2
+	SchemaVersion int      `json:"schemaVersion"`
+	MediaType     string   `json:"mediaType"`
+	Config        *Layer   `json:"config"`
+	Layers        []*Layer `json:"layers"`

	filepath string
	fi       os.FileInfo
@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {

	p := filepath.Join(manifests, n.Filepath())

-	var m ManifestV2
+	var m Manifest
	f, err := os.Open(p)
	if err != nil {
		return nil, err
@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
		return nil, err
	}

-	return &Manifest{
-		ManifestV2: m,
-		filepath:   p,
-		fi:         fi,
-		digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
-	}, nil
+	m.filepath = p
+	m.fi = fi
+	m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
+
+	return &m, nil
}

func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
	}
	defer f.Close()

-	m := ManifestV2{
+	m := Manifest{
		SchemaVersion: 2,
		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
		Config:        config,
@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
	}
	defer f.Close()

-	if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
+	if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
		t.Fatal(err)
	}
}
@@ -4,6 +4,7 @@ import (
	"archive/zip"
	"bytes"
	"context"
+	"encoding/json"
	"errors"
	"fmt"
	"io"
@@ -11,12 +12,14 @@ import (
	"net/http"
	"os"
	"path/filepath"
+	"slices"
	"strings"
+	"text/template/parse"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/convert"
	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/templates"
+	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
)

@@ -91,12 +94,11 @@ func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse))

	fn(api.ProgressResponse{Status: "unpacking model metadata"})
	for _, f := range r.File {
-		n := filepath.Join(p, f.Name)
-		if !strings.HasPrefix(n, p) {
-			slog.Warn("skipped extracting file outside of context", "name", f.Name)
-			continue
+		if !filepath.IsLocal(f.Name) {
+			return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
		}

+		n := filepath.Join(p, f.Name)
		if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
			return err
		}
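The new check delegates path validation to the standard library. Below is a tiny standalone sketch, not part of this diff (it needs Go 1.20+ for filepath.IsLocal), whose inputs mirror the TestExtractFromZipFile cases added later in this change: names that stay inside the extraction root are accepted, while names that escape it cause extraction to fail with zip.ErrInsecurePath.

// Standalone sketch: which archive member names filepath.IsLocal accepts.
package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	for _, name := range []string{
		"good",              // plain relative path -> local
		"path/../to/good",   // resolves inside the root -> local
		"../bad",            // escapes the root -> not local
		"path/../../to/bad", // also escapes once cleaned -> not local
	} {
		fmt.Printf("%-20s %v\n", name, filepath.IsLocal(filepath.FromSlash(name)))
	}
}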
@@ -258,7 +260,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
	for _, layer := range layers {
		if s := layer.GGML.KV().ChatTemplate(); s != "" {
-			if t, err := templates.NamedTemplate(s); err != nil {
+			if t, err := template.Named(s); err != nil {
				slog.Debug("template detection", "error", err)
			} else {
				tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
@@ -291,3 +293,87 @@ func detectContentType(r io.Reader) (string, error) {

 	return "unknown", nil
 }
+
+// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
+// mxyng: this only really works if the input contains tool calls in some JSON format
+func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
+	// create a subtree from the node that ranges over .ToolCalls
+	tmpl := m.Template.Subtree(func(n parse.Node) bool {
+		if t, ok := n.(*parse.RangeNode); ok {
+			return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
+		}
+
+		return false
+	})
+
+	if tmpl == nil {
+		return nil, false
+	}
+
+	var b bytes.Buffer
+	if err := tmpl.Execute(&b, map[string][]map[string]any{
+		"ToolCalls": {
+			{
+				"Function": map[string]any{
+					"Name":      "@@name@@",
+					"Arguments": "@@arguments@@",
+				},
+			},
+		},
+	}); err != nil {
+		return nil, false
+	}
+
+	var kv map[string]string
+	// execute the subtree with placeholders to identify the keys
+	// trim any commands that might exist in the template
+	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
+		return nil, false
+	}
+
+	// find the keys that correspond to the name and arguments fields
+	var name, arguments string
+	for k, v := range kv {
+		switch v {
+		case "@@name@@":
+			name = k
+		case "@@arguments@@":
+			arguments = k
+		}
+	}
+
+	var objs []map[string]any
+	for offset := 0; offset < len(s); {
+		if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
+			break
+		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
+			// skip over any syntax errors
+			offset += int(syntax.Offset)
+		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
+			// skip over any unmarshalable types
+			offset += int(unmarshalType.Offset)
+		} else if err != nil {
+			return nil, false
+		} else {
+			// break when an object is decoded
+			break
+		}
+	}
+
+	var toolCalls []api.ToolCall
+	for _, kv := range objs {
+		var call api.ToolCall
+		for k, v := range kv {
+			switch k {
+			case name:
+				call.Function.Name = v.(string)
+			case arguments:
+				call.Function.Arguments = v.(map[string]any)
+			}
+		}
+
+		toolCalls = append(toolCalls, call)
+	}
+
+	return toolCalls, len(toolCalls) > 0
+}
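The core trick in parseToolCalls is to execute just the .ToolCalls range of the chat template with sentinel values, then read back which JSON keys the sentinels landed under; that tells the parser how this particular model spells its tool-call objects. A minimal standalone sketch of the idea using text/template (the template string is an invented stand-in, not a real model template):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"text/template"
)

func main() {
	// Toy tool-call fragment: renders each call as a JSON object.
	tmpl := template.Must(template.New("").Parse(
		`{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": "{{ .Function.Arguments }}"}{{ end }}`))

	// Execute with sentinels so the rendered JSON reveals which key carries which field.
	var b bytes.Buffer
	if err := tmpl.Execute(&b, map[string][]map[string]any{
		"ToolCalls": {{"Function": map[string]any{"Name": "@@name@@", "Arguments": "@@arguments@@"}}},
	}); err != nil {
		panic(err)
	}

	var kv map[string]string
	if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
		panic(err)
	}
	fmt.Println(kv) // map[arguments:@@arguments@@ name:@@name@@]
}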
@@ -3,13 +3,19 @@ package server

 import (
 	"archive/zip"
 	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
 	"io"
 	"os"
 	"path/filepath"
 	"slices"
+	"strings"
 	"testing"

+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/template"
 )

 func createZipFile(t *testing.T, name string) *os.File {
@@ -39,13 +45,31 @@ func TestExtractFromZipFile(t *testing.T) {

 	cases := []struct {
 		name   string
 		expect []string
+		err    error
 	}{
 		{
 			name:   "good",
 			expect: []string{"good"},
 		},
 		{
-			name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"),
+			name:   strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
+			expect: []string{filepath.Join("to", "good")},
+		},
+		{
+			name:   strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
+			expect: []string{"good"},
+		},
+		{
+			name:   strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
+			expect: []string{"good"},
+		},
+		{
+			name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
+			err:  zip.ErrInsecurePath,
+		},
+		{
+			name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
+			err:  zip.ErrInsecurePath,
 		},
 	}
@@ -55,7 +79,7 @@ func TestExtractFromZipFile(t *testing.T) {

 			defer f.Close()

 			tempDir := t.TempDir()
-			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil {
+			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
 				t.Fatal(err)
 			}
@@ -90,3 +114,123 @@ func TestExtractFromZipFile(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.NewBuffer(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithTools(t *testing.T) {
|
||||||
|
p := filepath.Join("testdata", "tools")
|
||||||
|
cases := []struct {
|
||||||
|
model string
|
||||||
|
output string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||||
|
|
||||||
|
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||||
|
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||||
|
|
||||||
|
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"command-r-plus", "Action: ```json" + `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tool_name": "get_current_weather",
|
||||||
|
"parameters": {
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "Toronto, Canada"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
` + "```", true},
|
||||||
|
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||||
|
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
var tools []api.Tool
|
||||||
|
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var messages []api.Message
|
||||||
|
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
calls := []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: function{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"format": "fahrenheit",
|
||||||
|
"location": "San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: function{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"format": "celsius",
|
||||||
|
"location": "Toronto, Canada",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.model, func(t *testing.T) {
|
||||||
|
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("template", func(t *testing.T) {
|
||||||
|
var actual bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("parse", func(t *testing.T) {
|
||||||
|
m := &Model{Template: tmpl}
|
||||||
|
actual, ok := m.parseToolCalls(tt.output)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.ok {
|
||||||
|
if diff := cmp.Diff(actual, calls); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
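Several of the cases above wrap the tool-call array in prose or fenced blocks (mistral's [TOOL_CALLS] prefix, command-r-plus's Action block), so the parser has to locate JSON inside arbitrary text. A standalone sketch of that scanning loop: keep advancing past decode errors by their reported offset until an array of objects decodes:

package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
)

func main() {
	s := `Searching for the weather: [{"name": "get_current_weather", "arguments": {"location": "Toronto, Canada"}}]`

	var objs []map[string]any
	for offset := 0; offset < len(s); {
		err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs)
		if errors.Is(err, io.EOF) {
			break
		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
			offset += int(syntax.Offset) // skip text that is not valid JSON
		} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
			offset += int(unmarshalType.Offset) // skip values of the wrong shape
		} else if err != nil {
			return
		} else {
			break // decoded an array of objects
		}
	}

	fmt.Println(objs)
}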
@@ -103,18 +103,9 @@ func (mp ModelPath) GetShortTagname() string {
 	return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
 }

-// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
-// The models directory is where Ollama stores its model files and manifests.
-func modelsDir() (string, error) {
-	return envconfig.ModelsDir, nil
-}
-
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 func (mp ModelPath) GetManifestPath() (string, error) {
-	dir, err := modelsDir()
-	if err != nil {
-		return "", err
-	}
-
+	dir := envconfig.ModelsDir
 	return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
 }
@@ -127,10 +118,7 @@ func (mp ModelPath) BaseURL() *url.URL {
 }

 func GetManifestPath() (string, error) {
-	dir, err := modelsDir()
-	if err != nil {
-		return "", err
-	}
-
+	dir := envconfig.ModelsDir
 	path := filepath.Join(dir, "manifests")
 	if err := os.MkdirAll(path, 0o755); err != nil {
@@ -141,10 +129,7 @@ func GetManifestPath() (string, error) {
 }

 func GetBlobsPath(digest string) (string, error) {
-	dir, err := modelsDir()
-	if err != nil {
-		return "", err
-	}
-
+	dir := envconfig.ModelsDir
 	// only accept actual sha256 digests
 	pattern := "^sha256[:-][0-9a-fA-F]{64}$"
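GetBlobsPath keeps validating digests against the pattern shown above before touching the filesystem, which also guards against path traversal in the digest string. A quick standalone check of how that regexp behaves (the sha256-<hex> on-disk form is an assumption drawn from the [:-] alternative in the pattern):

package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	// only accept actual sha256 digests
	re := regexp.MustCompile("^sha256[:-][0-9a-fA-F]{64}$")

	fmt.Println(re.MatchString("sha256:" + strings.Repeat("a", 64))) // true
	fmt.Println(re.MatchString("sha256-" + strings.Repeat("0", 64))) // true
	fmt.Println(re.MatchString("sha256:../../../etc/passwd"))        // false
}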
server/prompt.go (245 changed lines)
@@ -1,221 +1,74 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"bytes"
|
||||||
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
"text/template/parse"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
// isResponseNode checks if the node contains .Response
|
type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||||
func isResponseNode(node *parse.ActionNode) bool {
|
|
||||||
for _, cmd := range node.Pipe.Cmds {
|
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||||
for _, arg := range cmd.Args {
|
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||||
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
|
// latest message and 2) system messages
|
||||||
if fieldNode.Ident[0] == "Response" {
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||||
return true
|
var system []api.Message
|
||||||
}
|
// always include the last message
|
||||||
|
n := len(msgs) - 1
|
||||||
|
// in reverse, find all messages that fit into context window
|
||||||
|
for i := n - 1; i >= 0; i-- {
|
||||||
|
system = make([]api.Message, 0)
|
||||||
|
for j := range i {
|
||||||
|
if msgs[j].Role == "system" {
|
||||||
|
system = append(system, msgs[j])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatTemplateForResponse formats the template AST to:
|
var b bytes.Buffer
|
||||||
// 1. remove all nodes after the first .Response (if generate=true)
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||||
// 2. add a .Response node to the end if it doesn't exist
|
return "", nil, err
|
||||||
// TODO(jmorganca): this should recursively cut the template before the first .Response
|
|
||||||
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
|
|
||||||
var found bool
|
|
||||||
for i, node := range tmpl.Tree.Root.Nodes {
|
|
||||||
if actionNode, ok := node.(*parse.ActionNode); ok {
|
|
||||||
if isResponseNode(actionNode) {
|
|
||||||
found = true
|
|
||||||
if generate {
|
|
||||||
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
s, err := tokenize(ctx, b.String())
|
||||||
// add the response node if it doesn't exist
|
|
||||||
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
|
|
||||||
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
|
|
||||||
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
|
|
||||||
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prompt renders a prompt from a template. If generate is set to true,
|
|
||||||
// the response and parts of the template following it are not rendered
|
|
||||||
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
|
|
||||||
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
formatTemplateForResponse(parsed, generate)
|
|
||||||
|
|
||||||
vars := map[string]any{
|
|
||||||
"System": system,
|
|
||||||
"Prompt": prompt,
|
|
||||||
"Response": response,
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
if err := parsed.Execute(&sb, vars); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
|
||||||
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens, err := encode(rendered)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to encode prompt", "err", err)
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(tokens), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
|
||||||
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
|
||||||
type prompt struct {
|
|
||||||
System string
|
|
||||||
Prompt string
|
|
||||||
Response string
|
|
||||||
|
|
||||||
images []int
|
|
||||||
tokens int
|
|
||||||
}
|
|
||||||
|
|
||||||
var p prompt
|
|
||||||
|
|
||||||
// iterate through messages to build up {system,user,response} prompts
|
|
||||||
var imgId int
|
|
||||||
var prompts []prompt
|
|
||||||
for _, msg := range messages {
|
|
||||||
switch strings.ToLower(msg.Role) {
|
|
||||||
case "system":
|
|
||||||
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
|
||||||
prompts = append(prompts, p)
|
|
||||||
p = prompt{}
|
|
||||||
}
|
|
||||||
|
|
||||||
p.System = msg.Content
|
|
||||||
case "user":
|
|
||||||
if p.Prompt != "" || p.Response != "" {
|
|
||||||
prompts = append(prompts, p)
|
|
||||||
p = prompt{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
for range msg.Images {
|
|
||||||
fmt.Fprintf(&sb, "[img-%d] ", imgId)
|
|
||||||
p.images = append(p.images, imgId)
|
|
||||||
imgId += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString(msg.Content)
|
|
||||||
p.Prompt = sb.String()
|
|
||||||
case "assistant":
|
|
||||||
if p.Response != "" {
|
|
||||||
prompts = append(prompts, p)
|
|
||||||
p = prompt{}
|
|
||||||
}
|
|
||||||
|
|
||||||
p.Response = msg.Content
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// add final prompt
|
|
||||||
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
|
||||||
prompts = append(prompts, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculate token lengths for each prompt, estimating 768 tokens per images
|
|
||||||
for i, p := range prompts {
|
|
||||||
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
prompts[i].tokens = tokens + len(prompts[i].images)*768
|
c := len(s)
|
||||||
}
|
if m.ProjectorPaths != nil {
|
||||||
|
for _, m := range msgs[i:] {
|
||||||
// truncate images and prompts starting from the beginning of the list
|
// images are represented as 768 sized embeddings
|
||||||
// until either one prompt remains or the total tokens fits the context window
|
// TODO: get embedding length from project metadata
|
||||||
// TODO (jmorganca): this doesn't account for the context window room required for the response
|
c += 768 * len(m.Images)
|
||||||
for {
|
}
|
||||||
var required int
|
|
||||||
for _, p := range prompts {
|
|
||||||
required += p.tokens
|
|
||||||
}
|
}
|
||||||
|
|
||||||
required += 1 // for bos token
|
if c > opts.NumCtx {
|
||||||
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
|
||||||
if required <= window {
|
|
||||||
slog.Debug("prompt now fits in context window", "required", required, "window", window)
|
|
||||||
break
|
break
|
||||||
|
} else {
|
||||||
|
n = i
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt := &prompts[0]
|
|
||||||
|
|
||||||
if len(prompt.images) > 1 {
|
|
||||||
img := prompt.images[0]
|
|
||||||
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
|
|
||||||
prompt.images = prompt.images[1:]
|
|
||||||
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
|
|
||||||
prompt.tokens -= 768
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(prompts) > 1 {
|
|
||||||
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
|
|
||||||
system := prompt.System
|
|
||||||
prompts = prompts[1:]
|
|
||||||
|
|
||||||
if system != "" && prompts[0].System == "" {
|
|
||||||
prompts[0].System = system
|
|
||||||
|
|
||||||
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
prompts[0].tokens = tokens + len(prompts[0].images)*768
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop truncating if there's only one prompt left
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
// truncate any messages that do not fit into the context window
|
||||||
for i, p := range prompts {
|
var b bytes.Buffer
|
||||||
// last prompt should leave the response unrendered (for completion)
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
|
||||||
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
|
return "", nil, err
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
sb.WriteString(rendered)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
for _, m := range msgs[n:] {
|
||||||
|
for _, i := range m.Images {
|
||||||
|
images = append(images, llm.ImageData{
|
||||||
|
ID: len(images),
|
||||||
|
Data: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), images, nil
|
||||||
}
|
}
|
||||||
|
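chatPrompt above always keeps the latest message, walks the history in reverse until the rendered prompt (plus roughly 768 tokens per image when a projector is present) no longer fits opts.NumCtx, and then re-renders using only the surviving messages plus any system messages. A hedged usage sketch; ctx, m, and msgs are assumed to exist in the calling handler, and the word-splitting tokenizer is only a stand-in for the runner's real Tokenize:

// count one token per whitespace-separated word; a crude stand-in for llm tokenization
tokenize := func(_ context.Context, s string) ([]int, error) {
	return make([]int, len(strings.Fields(s))), nil
}

opts := api.Options{Runner: api.Runner{NumCtx: 64}}
prompt, images, err := chatPrompt(ctx, m, tokenize, &opts, msgs, nil)
if err != nil {
	// handle as the callers in routes.go do
}
// prompt is the template rendered over the retained messages;
// images carries llm.ImageData entries (ID, Data) for every image that survived truncation
_ = prompt
_ = images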
@@ -1,204 +1,209 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"bytes"
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPrompt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
template string
|
|
||||||
system string
|
|
||||||
prompt string
|
|
||||||
response string
|
|
||||||
generate bool
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "simple prompt",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
prompt: "What are the potion ingredients?",
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "implicit response",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
prompt: "What are the potion ingredients?",
|
|
||||||
response: "I don't know.",
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "response",
|
|
||||||
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
prompt: "What are the potion ingredients?",
|
|
||||||
response: "I don't know.",
|
|
||||||
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "cut",
|
|
||||||
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
prompt: "What are the potion ingredients?",
|
|
||||||
response: "I don't know.",
|
|
||||||
generate: true,
|
|
||||||
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nocut",
|
|
||||||
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
|
|
||||||
system: "You are a Wizard.",
|
|
||||||
prompt: "What are the potion ingredients?",
|
|
||||||
response: "I don't know.",
|
|
||||||
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("got = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChatPrompt(t *testing.T) {
|
func TestChatPrompt(t *testing.T) {
|
||||||
tests := []struct {
|
type expect struct {
|
||||||
name string
|
prompt string
|
||||||
template string
|
images [][]byte
|
||||||
messages []api.Message
|
}
|
||||||
window int
|
|
||||||
want string
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
limit int
|
||||||
|
msgs []api.Message
|
||||||
|
expect
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "simple prompt",
|
name: "messages",
|
||||||
template: "[INST] {{ .Prompt }} [/INST]",
|
limit: 64,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "[INST] Hello [/INST]",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with system message",
|
name: "truncate messages",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
limit: 1,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with response",
|
name: "truncate messages with image",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
|
limit: 64,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
{Role: "assistant", Content: "I am?"},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
images: [][]byte{
|
||||||
|
[]byte("something"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with implicit response",
|
name: "truncate messages with images",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
|
limit: 64,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
{Role: "assistant", Content: "I am?"},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "[img-0] A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
images: [][]byte{
|
||||||
|
[]byte("somethingelse"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with conversation",
|
name: "messages with images",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
limit: 2048,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||||
{Role: "user", Content: "What are the potion ingredients?"},
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
{Role: "assistant", Content: "sugar"},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||||
{Role: "user", Content: "Anything else?"},
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "[img-0] You're a test, Harry! I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
images: [][]byte{
|
||||||
|
[]byte("something"),
|
||||||
|
[]byte("somethingelse"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with truncation",
|
name: "message with image tag",
|
||||||
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
|
limit: 2048,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
||||||
{Role: "user", Content: "Hello"},
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
{Role: "assistant", Content: "I am?"},
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||||
{Role: "user", Content: "Why is the sky blue?"},
|
},
|
||||||
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
|
expect: expect{
|
||||||
|
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1] A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
images: [][]byte{
|
||||||
|
[]byte("something"),
|
||||||
|
[]byte("somethingelse"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
window: 10,
|
|
||||||
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "images",
|
name: "messages with interleaved images",
|
||||||
template: "{{ .System }} {{ .Prompt }}",
|
limit: 2048,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||||
|
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
images: [][]byte{
|
||||||
|
[]byte("something"),
|
||||||
|
[]byte("somethingelse"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "You are a Wizard. [img-0] Hello",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "images truncated",
|
name: "truncate message with interleaved images",
|
||||||
template: "{{ .System }} {{ .Prompt }}",
|
limit: 1024,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "system", Content: "You are a Wizard."},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||||
|
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
images: [][]byte{
|
||||||
|
[]byte("somethingelse"),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "You are a Wizard. [img-0] [img-1] Hello",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty list",
|
name: "message with system prompt",
|
||||||
template: "{{ .System }} {{ .Prompt }}",
|
limit: 2048,
|
||||||
messages: []api.Message{},
|
msgs: []api.Message{
|
||||||
window: 1024,
|
{Role: "system", Content: "You are the Test Who Lived."},
|
||||||
want: "",
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty prompt",
|
name: "out of order system",
|
||||||
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
|
limit: 2048,
|
||||||
messages: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: ""},
|
{Role: "user", Content: "You're a test, Harry!"},
|
||||||
|
{Role: "assistant", Content: "I-I'm a what?"},
|
||||||
|
{Role: "system", Content: "You are the Test Who Lived."},
|
||||||
|
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||||
|
},
|
||||||
|
expect: expect{
|
||||||
|
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
|
||||||
},
|
},
|
||||||
window: 1024,
|
|
||||||
want: "",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
encode := func(s string) ([]int, error) {
|
tmpl, err := template.Parse(`
|
||||||
words := strings.Fields(s)
|
{{- if .System }}{{ .System }} {{ end }}
|
||||||
return make([]int, len(words)), nil
|
{{- if .Prompt }}{{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}{{ .Response }} {{ end }}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tt := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||||
|
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||||
|
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error = %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got != tc.want {
|
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||||
t.Errorf("got: %q, want: %q", got, tc.want)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(images) != len(tt.images) {
|
||||||
|
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range images {
|
||||||
|
if images[i].ID != i {
|
||||||
|
t.Errorf("expected ID %d, got %d", i, images[i].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(images[i].Data, tt.images[i]) {
|
||||||
|
t.Errorf("expected %q, got %q", tt.images[i], images[i])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
server/routes.go (693 changed lines)
@@ -1,13 +1,13 @@
 package server

 import (
+	"bytes"
 	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log/slog"
 	"math"
 	"net"
@@ -17,7 +17,6 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
-	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -31,6 +30,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -55,7 +55,7 @@ func init() {
 	gin.SetMode(mode)
 }

-var defaultSessionDuration = 5 * time.Minute
+var errRequired = errors.New("is required")

 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -70,277 +70,225 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }

-func isSupportedImageType(image []byte) bool {
-	contentType := http.DetectContentType(image)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
-	return slices.Contains(allowedTypes, contentType)
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+	if name == "" {
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
+	}
+
+	model, err := GetModel(name)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	if err := model.CheckCapabilities(caps...); err != nil {
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
+	}
+
+	opts, err := modelOptions(model, requestOpts)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+	var runner *runnerRef
+	select {
+	case runner = <-runnerCh:
+	case err = <-errCh:
+		return nil, nil, nil, err
+	}
+
+	return runner.llama, model, &opts, nil
 }
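Every handler below goes through scheduleRunner instead of calling GetModel and the scheduler directly. A condensed sketch of the call pattern used further down in GenerateHandler (s, c, and req are the handler's receiver, gin context, and bound request):

caps := []Capability{CapabilityCompletion}
if req.Suffix != "" {
	caps = append(caps, CapabilityInsert) // infill additionally requires insert support
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if err != nil {
	handleScheduleError(c, req.Model, err)
	return
}
// r is the scheduled llm.LlamaServer, m the resolved *Model, opts the consolidated *api.Options
_, _, _ = r, m, opts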
|
|
||||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
|
|
||||||
switch {
|
|
||||||
case errors.Is(err, io.EOF):
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
return
|
return
|
||||||
case err != nil:
|
} else if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate the request
|
if req.Format != "" && req.Format != "json" {
|
||||||
switch {
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
|
||||||
case req.Model == "":
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
||||||
return
|
return
|
||||||
case len(req.Format) > 0 && req.Format != "json":
|
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
|
||||||
return
|
|
||||||
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, img := range req.Images {
|
caps := []Capability{CapabilityCompletion}
|
||||||
if !isSupportedImageType(img) {
|
if req.Suffix != "" {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
caps = append(caps, CapabilityInsert)
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
model, err := GetModel(req.Model)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
if err != nil {
|
if errors.Is(err, errCapabilityCompletion) {
|
||||||
var pErr *fs.PathError
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||||
if errors.As(err, &pErr) {
|
return
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
} else if err != nil {
|
||||||
return
|
handleScheduleError(c, req.Model, err)
|
||||||
}
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.IsEmbedding() {
|
checkpointLoaded := time.Now()
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opts, err := modelOptions(model, req.Options)
|
if req.Prompt == "" {
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
|
||||||
if req.KeepAlive == nil {
|
|
||||||
sessionDuration = getDefaultSessionDuration()
|
|
||||||
} else {
|
|
||||||
sessionDuration = req.KeepAlive.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
||||||
var runner *runnerRef
|
|
||||||
select {
|
|
||||||
case runner = <-rCh:
|
|
||||||
case err = <-eCh:
|
|
||||||
handleErrorResponse(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// an empty request loads the model
|
|
||||||
// note: for a short while template was used in lieu
|
|
||||||
// of `raw` mode so we need to check for it too
|
|
||||||
if req.Prompt == "" && req.Template == "" && req.System == "" {
|
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
CreatedAt: time.Now().UTC(),
|
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
Done: true,
|
Done: true,
|
||||||
DoneReason: "load",
|
DoneReason: "load",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
images := make([]llm.ImageData, len(req.Images))
|
||||||
|
for i := range req.Images {
|
||||||
|
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
||||||
|
}
|
||||||
|
|
||||||
var prompt string
|
prompt := req.Prompt
|
||||||
switch {
|
if !req.Raw {
|
||||||
case req.Raw:
|
tmpl := m.Template
|
||||||
prompt = req.Prompt
|
if req.Template != "" {
|
||||||
case req.Prompt != "":
|
tmpl, err = template.Parse(req.Template)
|
||||||
if req.Template == "" {
|
if err != nil {
|
||||||
req.Template = model.Template
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.System == "" {
|
var b bytes.Buffer
|
||||||
req.System = model.System
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("generate handler", "prompt", req.Prompt)
|
|
||||||
slog.Debug("generate handler", "template", req.Template)
|
|
||||||
slog.Debug("generate handler", "system", req.System)
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
for i := range req.Images {
|
|
||||||
fmt.Fprintf(&sb, "[img-%d] ", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString(req.Prompt)
|
|
||||||
|
|
||||||
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.Reset()
|
|
||||||
if req.Context != nil {
|
if req.Context != nil {
|
||||||
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
|
s, err := r.Detokenize(c.Request.Context(), req.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(prev)
|
b.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(p)
|
var values template.Values
|
||||||
|
if req.Suffix != "" {
|
||||||
prompt = sb.String()
|
values.Prompt = prompt
|
||||||
}
|
values.Suffix = req.Suffix
|
||||||
|
} else {
|
||||||
slog.Debug("generate handler", "prompt", prompt)
|
var msgs []api.Message
|
||||||
|
if req.System != "" {
|
||||||
ch := make(chan any)
|
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
||||||
var generated strings.Builder
|
} else if m.System != "" {
|
||||||
go func() {
|
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
|
||||||
defer close(ch)
|
|
||||||
|
|
||||||
fn := func(r llm.CompletionResponse) {
|
|
||||||
// Build up the full response
|
|
||||||
if _, err := generated.WriteString(r.Content); err != nil {
|
|
||||||
ch <- gin.H{"error": err.Error()}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := api.GenerateResponse{
|
for _, i := range images {
|
||||||
|
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
||||||
|
}
|
||||||
|
|
||||||
|
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpl.Execute(&b, values); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt = b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("generate request", "prompt", prompt, "images", images)
|
||||||
|
|
||||||
|
ch := make(chan any)
|
||||||
|
go func() {
|
||||||
|
// TODO (jmorganca): avoid building the response twice both here and below
|
||||||
|
var sb strings.Builder
|
||||||
|
defer close(ch)
|
||||||
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
|
Prompt: prompt,
|
||||||
|
Images: images,
|
||||||
|
Format: req.Format,
|
||||||
|
Options: opts,
|
||||||
|
}, func(cr llm.CompletionResponse) {
|
||||||
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Done: r.Done,
|
Response: cr.Content,
|
||||||
Response: r.Content,
|
Done: cr.Done,
|
||||||
DoneReason: r.DoneReason,
|
DoneReason: cr.DoneReason,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: r.PromptEvalCount,
|
PromptEvalCount: cr.PromptEvalCount,
|
||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Done {
|
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||||
resp.TotalDuration = time.Since(checkpointStart)
|
ch <- gin.H{"error": err.Error()}
|
||||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
}
|
||||||
|
|
||||||
|
if cr.Done {
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
|
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO (jmorganca): encode() should not strip special tokens
|
|
||||||
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
res.Context = append(req.Context, tokens...)
|
||||||
resp.Context = append(req.Context, tokens...)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- resp
|
ch <- res
|
||||||
}
|
}); err != nil {
|
||||||
|
|
||||||
var images []llm.ImageData
|
|
||||||
for i := range req.Images {
|
|
||||||
images = append(images, llm.ImageData{
|
|
||||||
ID: i,
|
|
||||||
Data: req.Images[i],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start prediction
|
|
||||||
req := llm.CompletionRequest{
|
|
||||||
Prompt: prompt,
|
|
||||||
Format: req.Format,
|
|
||||||
Images: images,
|
|
||||||
Options: opts,
|
|
||||||
}
|
|
||||||
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if req.Stream != nil && !*req.Stream {
|
if req.Stream != nil && !*req.Stream {
|
||||||
// Accumulate responses into the final response
|
var r api.GenerateResponse
|
||||||
var final api.GenerateResponse
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for resp := range ch {
|
for rr := range ch {
|
||||||
switch r := resp.(type) {
|
switch t := rr.(type) {
|
||||||
case api.GenerateResponse:
|
case api.GenerateResponse:
|
||||||
sb.WriteString(r.Response)
|
sb.WriteString(t.Response)
|
||||||
final = r
|
r = t
|
||||||
case gin.H:
|
case gin.H:
|
||||||
if errorMsg, ok := r["error"].(string); ok {
|
msg, ok := t["error"].(string)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
if !ok {
|
||||||
return
|
msg = "unexpected error format in response"
|
||||||
} else {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
final.Response = sb.String()
|
r.Response = sb.String()
|
||||||
c.JSON(http.StatusOK, final)
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
|
r.ToolCalls = toolCalls
|
||||||
|
r.Response = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
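The handler above fans every result through a single chan any: a goroutine produces api.GenerateResponse chunks (or a gin.H error), and either streamResponse forwards them or the non-streaming branch folds them into one response. A minimal standalone sketch of that shape, with strings standing in for response chunks:

package main

import "fmt"

func main() {
	ch := make(chan any)

	// producer: emit partial results, then close the channel when done
	go func() {
		defer close(ch)
		for _, chunk := range []string{"The sky ", "is blue."} {
			ch <- chunk
		}
	}()

	// non-streaming consumer: accumulate chunks into a single response
	var full string
	for v := range ch {
		switch t := v.(type) {
		case string:
			full += t
		default:
			fmt.Println("unexpected response")
			return
		}
	}
	fmt.Println(full) // The sky is blue.
}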
|
|
||||||
func getDefaultSessionDuration() time.Duration {
|
func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
if envconfig.KeepAlive != "" {
|
var req api.EmbedRequest
|
||||||
v, err := strconv.Atoi(envconfig.KeepAlive)
|
|
||||||
if err != nil {
|
|
||||||
d, err := time.ParseDuration(envconfig.KeepAlive)
|
|
||||||
if err != nil {
|
|
||||||
return defaultSessionDuration
|
|
||||||
}
|
|
||||||
|
|
||||||
if d < 0 {
|
|
||||||
return time.Duration(math.MaxInt64)
|
|
||||||
}
|
|
||||||
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
d := time.Duration(v) * time.Second
|
|
||||||
if d < 0 {
|
|
||||||
return time.Duration(math.MaxInt64)
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultSessionDuration
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
||||||
var req api.EmbeddingRequest
|
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, io.EOF):
|
case errors.Is(err, io.EOF):
|
||||||
@@ -351,41 +299,122 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	if req.Model == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+	truncate := true
+
+	if req.Truncate != nil && !*req.Truncate {
+		truncate = false
+	}
+
+	var input []string
+
+	switch i := req.Input.(type) {
+	case string:
+		if len(i) > 0 {
+			input = append(input, i)
+		}
+	case []any:
+		for _, v := range i {
+			if _, ok := v.(string); !ok {
+				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
+				return
+			}
+			input = append(input, v.(string))
+		}
+	default:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
 		return
 	}
 
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	if len(input) == 0 {
+		c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
+		return
+	}
+
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	kvData, err := getKVData(m.ModelPath, false)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	for i, s := range input {
+		tokens, err := r.Tokenize(c.Request.Context(), s)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+
+		ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
+		if len(tokens) > ctxLen {
+			if !truncate {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
+				return
+			}
+
+			tokens = tokens[:ctxLen]
+			s, err = r.Detokenize(c.Request.Context(), tokens)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+		}
+
+		input[i] = s
+	}
+
+	embeddings, err := r.Embed(c.Request.Context(), input)
+	if err != nil {
+		slog.Error("embedding generation failed", "error", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+		return
+	}
+
+	for i, e := range embeddings {
+		embeddings[i] = normalize(e)
+	}
+
+	resp := api.EmbedResponse{
+		Model:      req.Model,
+		Embeddings: embeddings,
+	}
+	c.JSON(http.StatusOK, resp)
+}
+
+func normalize(vec []float32) []float32 {
+	var sum float32
+	for _, v := range vec {
+		sum += v * v
+	}
+
+	norm := float32(0.0)
+	if sum > 0 {
+		norm = float32(1.0 / math.Sqrt(float64(sum)))
+	}
+
+	for i := range vec {
+		vec[i] *= norm
+	}
+	return vec
+}
+
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
+	var req api.EmbeddingRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
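For illustration only, not part of the change set: a minimal client-side sketch of the new /api/embed endpoint added above. It assumes a server listening on the default localhost:11434 and a pulled embedding model named "all-minilm"; neither assumption comes from this diff.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

type embedRequest struct {
	Model string   `json:"model"`
	Input []string `json:"input"`
}

type embedResponse struct {
	Model      string      `json:"model"`
	Embeddings [][]float32 `json:"embeddings"`
}

func main() {
	// Build a batched embedding request; Input may also be a single string.
	body, _ := json.Marshal(embedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is grass green?"},
	})

	resp, err := http.Post("http://localhost:11434/api/embed", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out embedResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	// Each returned vector is unit-normalized by the server (see normalize above).
	fmt.Println(len(out.Embeddings), "embeddings returned")
}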
@@ -395,13 +424,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}
 
+	embedding := make([]float64, len(embeddings[0]))
+
+	for i, v := range embeddings[0] {
+		embedding[i] = float64(v)
+	}
+
 	resp := api.EmbeddingResponse{
 		Embedding: embedding,
 	}
@@ -679,13 +715,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 		m.System = req.System
 	}
 
-	if req.Template != "" {
-		m.Template = req.Template
-	}
-
-	msgs := make([]api.Message, 0)
-	for _, msg := range m.Messages {
-		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	msgs := make([]api.Message, len(m.Messages))
+	for i, msg := range m.Messages {
+		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
 	}
 
 	n := model.ParseName(req.Model)
@@ -701,7 +733,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	resp := &api.ShowResponse{
 		License:    strings.Join(m.License, "\n"),
 		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
 		Details:    modelDetails,
 		Messages:   msgs,
 		ModifiedAt: manifest.fi.ModTime(),
@@ -1028,6 +1060,7 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/pull", s.PullModelHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
+	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
 	r.POST("/api/create", s.CreateModelHandler)
 	r.POST("/api/push", s.PushModelHandler)
@@ -1039,7 +1072,11 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/api/ps", s.ProcessHandler)
 
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
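For illustration only, not part of the change set: a quick sketch of the OpenAI-compatible listing route registered above. It assumes a local server on localhost:11434; the response shape (object, data, id, owned_by) mirrors what the new Test_Routes cases later in this diff assert against openai.ListCompletion.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

type openAIModelList struct {
	Object string `json:"object"`
	Data   []struct {
		Id      string `json:"id"`
		OwnedBy string `json:"owned_by"`
	} `json:"data"`
}

func main() {
	// GET /v1/models returns the locally available models in OpenAI list form.
	resp, err := http.Get("http://localhost:11434/v1/models")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var list openAIModelList
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		panic(err)
	}

	for _, m := range list.Data {
		fmt.Println(m.Id, m.OwnedBy)
	}
}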
@@ -1245,139 +1282,67 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
-	encode := func(s string) ([]int, error) {
-		return runner.llama.Tokenize(ctx, s)
-	}
-
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
-	if err != nil {
-		return "", err
-	}
-
-	return prompt, nil
-}
-
 func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 
 	var req api.ChatRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
+	caps := []Capability{CapabilityCompletion}
+	if req.Tools != nil {
+		caps = append(caps, CapabilityTools)
 	}
 
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
 	checkpointLoaded := time.Now()
 
-	// if the first message is not a system message, then add the model's default system message
-	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
-		req.Messages = append([]api.Message{
-			{
-				Role:    "system",
-				Content: model.System,
-			},
-		}, req.Messages...)
-	}
-
-	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	// an empty request loads the model
-	if len(req.Messages) == 0 || prompt == "" {
-		resp := api.ChatResponse{
-			CreatedAt:  time.Now().UTC(),
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
 			Done:       true,
 			DoneReason: "load",
-			Message:    api.Message{Role: "assistant"},
-		}
-		c.JSON(http.StatusOK, resp)
+		})
 		return
 	}
 
-	// only send images that are in the prompt
-	var i int
-	var images []llm.ImageData
-	for _, m := range req.Messages {
-		for _, img := range m.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-
-			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
-				images = append(images, llm.ImageData{Data: img, ID: i})
-			}
-			i += 1
-		}
+	if req.Messages[0].Role != "system" && m.System != "" {
+		req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
 	}
 
-	slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	slog.Debug("chat request", "images", len(images), "prompt", prompt)
 
 	ch := make(chan any)
 
 	go func() {
 		defer close(ch)
-		fn := func(r llm.CompletionResponse) {
-			resp := api.ChatResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			res := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
 				Message:   api.Message{Role: "assistant", Content: r.Content},
@@ -1392,62 +1357,62 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
-			ch <- resp
-		}
-
-		if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}, fn); err != nil {
+			ch <- res
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.ChatResponse
+		var resp api.ChatResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(r.Message.Content)
-				final = r
+				sb.WriteString(t.Message.Content)
+				resp = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}
 
-		final.Message = api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, final)
+		resp.Message.Content = sb.String()
+		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+			resp.Message.ToolCalls = toolCalls
+			resp.Message.Content = ""
+		}
+
+		c.JSON(http.StatusOK, resp)
 		return
 	}
 
 	streamResponse(c, ch)
 }
 
-func handleErrorResponse(c *gin.Context, err error) {
-	if errors.Is(err, context.Canceled) {
+func handleScheduleError(c *gin.Context, name string, err error) {
+	switch {
+	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
-		return
-	}
-
-	if errors.Is(err, ErrMaxQueue) {
+	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
-		return
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
-
-	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 }
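For illustration only, not part of the change set: a small sketch of the new non-streaming chat path above, which accumulates the streamed chunks server-side and, when the model emits a tool call, returns it in message.tool_calls with empty content. The port, model name, and the exact tool schema below are assumptions, not taken from this diff.

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	// "stream": false exercises the accumulation + parseToolCalls branch.
	payload := `{
	  "model": "llama3.1",
	  "stream": false,
	  "messages": [{"role": "user", "content": "What is the weather in Toronto?"}],
	  "tools": [{
	    "type": "function",
	    "function": {
	      "name": "get_current_weather",
	      "description": "Get the current weather for a city",
	      "parameters": {
	        "type": "object",
	        "properties": {"city": {"type": "string"}},
	        "required": ["city"]
	      }
	    }
	  }]
	}`

	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", strings.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // message.tool_calls holds any parsed tool call
}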
@@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
 }
 
 func TestCreateFromBin(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
 }
 
 func TestCreateFromModel(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
 }
 
 func TestCreateRemovesLayers(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
 }
 
 func TestCreateUnsetsSystem(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
 }
 
 func TestCreateMergeParameters(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
 }
 
 func TestCreateReplacesMessages(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
 }
 
 func TestCreateTemplateSystem(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) {
 }
 
 func TestCreateLicenses(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) {
 }
 
 func TestCreateDetectTemplate(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -545,9 +563,9 @@ func TestCreateDetectTemplate(t *testing.T) {
 	}
 
 	checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
-		filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
-		filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
 		filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
+		filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"),
+		filepath.Join(p, "blobs", "sha256-f836ee110db21567f826332e4cedd746c06d10664fd5a9ea3659e3683a944510"),
 	})
 })
@@ -8,12 +8,15 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )
 
 func TestDelete(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
 }
 
 func TestDeleteDuplicateLayers(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	var s Server
server/routes_generate_test.go (new file, 712 lines)
@@ -0,0 +1,712 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockRunner struct {
|
||||||
|
llm.LlamaServer
|
||||||
|
|
||||||
|
// CompletionRequest is only valid until the next call to Completion
|
||||||
|
llm.CompletionRequest
|
||||||
|
llm.CompletionResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||||
|
m.CompletionRequest = r
|
||||||
|
fn(m.CompletionResponse)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
||||||
|
for range strings.Fields(s) {
|
||||||
|
tokens = append(tokens, len(tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||||
|
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
|
return mock, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateChat(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
// add 10ms delay to simulate loading
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
|
TEMPLATE """
|
||||||
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||||
|
`, createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("missing body", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, nil)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities chat", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(0),
|
||||||
|
}, []llm.Tensor{})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual api.ChatResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != "test" {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done true, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "load" {
|
||||||
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var actual api.ChatResponse
|
||||||
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != model {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "stop" {
|
||||||
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(actual.Message, api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: content,
|
||||||
|
}); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalCount == 0 {
|
||||||
|
t.Errorf("expected prompt eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalDuration == 0 {
|
||||||
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalCount == 0 {
|
||||||
|
t.Errorf("expected eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalDuration == 0 {
|
||||||
|
t.Errorf("expected eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.LoadDuration == 0 {
|
||||||
|
t.Errorf("expected load duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.TotalDuration == 0 {
|
||||||
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Hi!"
|
||||||
|
t.Run("messages", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("messages with model system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||||
|
t.Run("messages with system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "system", Content: "You can perform magic tricks."},
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("messages with interleaved system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
{Role: "assistant", Content: "I can help you with that."},
|
||||||
|
{Role: "system", Content: "You can perform magic tricks."},
|
||||||
|
{Role: "user", Content: "Help me write tests."},
|
||||||
|
},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerate(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
mock := mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: "stop",
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(&mock),
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
getCpuFn: gpu.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
|
||||||
|
req.successCh <- &runnerRef{
|
||||||
|
llama: &mock,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(context.TODO())
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Modelfile: fmt.Sprintf(`FROM %s
|
||||||
|
TEMPLATE """
|
||||||
|
{{- if .System }}System: {{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
||||||
|
`, createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("missing body", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, nil)
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities generate", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(0),
|
||||||
|
}, []llm.Tensor{})),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "bert",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing capabilities suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("expected status 400, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("load model", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var actual api.GenerateResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != "test" {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done true, got false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "load" {
|
||||||
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var actual api.GenerateResponse
|
||||||
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Model != model {
|
||||||
|
t.Errorf("expected model test, got %s", actual.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !actual.Done {
|
||||||
|
t.Errorf("expected done false, got true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.DoneReason != "stop" {
|
||||||
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Response != content {
|
||||||
|
t.Errorf("expected response %s, got %s", content, actual.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.Context == nil {
|
||||||
|
t.Errorf("expected context not nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalCount == 0 {
|
||||||
|
t.Errorf("expected prompt eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.PromptEvalDuration == 0 {
|
||||||
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalCount == 0 {
|
||||||
|
t.Errorf("expected eval count > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.EvalDuration == 0 {
|
||||||
|
t.Errorf("expected eval duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.LoadDuration == 0 {
|
||||||
|
t.Errorf("expected load duration > 0, got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if actual.TotalDuration == 0 {
|
||||||
|
t.Errorf("expected total duration > 0, got 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Hi!"
|
||||||
|
t.Run("prompt", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with model system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
||||||
|
})
|
||||||
|
|
||||||
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
||||||
|
t.Run("prompt with system", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Hello!",
|
||||||
|
System: "You can perform magic tricks.",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt with template", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Help me write tests.",
|
||||||
|
System: "You can perform magic tricks.",
|
||||||
|
Template: `{{- if .System }}{{ .System }} {{ end }}
|
||||||
|
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
||||||
|
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Modelfile: `FROM test
|
||||||
|
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||||
|
{{- else }}{{ .Prompt }}
|
||||||
|
{{- end }}"""`,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("prompt with suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
Suffix: " return c",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prompt without suffix", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-suffix",
|
||||||
|
Prompt: "def add(",
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||||
|
Model: "test-system",
|
||||||
|
Prompt: "Help me write tests.",
|
||||||
|
Raw: true,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@@ -7,11 +7,14 @@ import (
 	"slices"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 )
 
 func TestList(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
 	envconfig.LoadConfig()
 
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
@@ -20,6 +21,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -105,6 +107,24 @@ func Test_Routes(t *testing.T) {
 				assert.Empty(t, len(modelList.Models))
 			},
 		},
+		{
+			Name:   "openai empty list",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Equal(t, "list", modelList.Object)
+				assert.Empty(t, modelList.Data)
+			},
+		},
 		{
 			Name:   "Tags Handler (yes tags)",
 			Method: http.MethodGet,
@@ -128,6 +148,25 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
 			},
 		},
+		{
+			Name:   "openai list models with tags",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Len(t, modelList.Data, 1)
+				assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
+				assert.Equal(t, "library", modelList.Data[0].OwnedBy)
+			},
+		},
 		{
 			Name:   "Create Model Handler",
 			Method: http.MethodPost,
@@ -216,6 +255,95 @@ func Test_Routes(t *testing.T) {
 			assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
 			},
 		},
+		{
+			Name:   "openai retrieve model handler",
+			Method: http.MethodGet,
+			Path:   "/v1/models/show-model",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var retrieveResp api.RetrieveModelResponse
+				err = json.Unmarshal(body, &retrieveResp)
+				require.NoError(t, err)
+
+				assert.Equal(t, "show-model", retrieveResp.Id)
+				assert.Equal(t, "library", retrieveResp.OwnedBy)
+			},
+		},
+		{
+			Name:   "Embed Handler Empty Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: "",
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				if contentType != "application/json; charset=utf-8" {
+					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
+				}
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				var embedResp api.EmbedResponse
+				err = json.Unmarshal(body, &embedResp)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if embedResp.Model != "t-bone" {
+					t.Fatalf("expected model t-bone, got %s", embedResp.Model)
+				}
+
+				if embedResp.Embeddings == nil {
+					t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
+				}
+
+				if len(embedResp.Embeddings) != 0 {
+					t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
+				}
+			},
+		},
+		{
+			Name:   "Embed Handler Invalid Input",
+			Method: http.MethodPost,
+			Path:   "/api/embed",
+			Setup: func(t *testing.T, req *http.Request) {
+				embedReq := api.EmbedRequest{
+					Model: "t-bone",
+					Input: 2,
+				}
+				jsonData, err := json.Marshal(embedReq)
+				require.NoError(t, err)
+				req.Body = io.NopCloser(bytes.NewReader(jsonData))
+			},
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				if contentType != "application/json; charset=utf-8" {
+					t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
+				}
+				_, err := io.ReadAll(resp.Body)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				if resp.StatusCode != http.StatusBadRequest {
+					t.Fatalf("expected status code 400, got %d", resp.StatusCode)
+				}
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -364,3 +492,38 @@ func TestShow(t *testing.T) {
|
|||||||
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
|
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalize(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
input []float32
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{input: []float32{1}},
|
||||||
|
{input: []float32{0, 1, 2, 3}},
|
||||||
|
{input: []float32{0.1, 0.2, 0.3}},
|
||||||
|
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||||
|
{input: []float32{0, 0, 0}},
|
||||||
|
}
|
||||||
|
|
||||||
|
isNormalized := func(vec []float32) (res bool) {
|
||||||
|
sum := 0.0
|
||||||
|
for _, v := range vec {
|
||||||
|
sum += float64(v * v)
|
||||||
|
}
|
||||||
|
if math.Abs(sum-1) > 1e-6 {
|
||||||
|
return sum == 0
|
||||||
|
} else {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
normalized := normalize(tc.input)
|
||||||
|
if !isNormalized(normalized) {
|
||||||
|
t.Errorf("Vector %v is not normalized", tc.input)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
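TestNormalize only asserts an invariant: every normalized vector must have unit L2 norm, except the all-zero vector, which must keep norm zero (the `sum == 0` escape hatch in isNormalized). A minimal sketch of a `normalize` that satisfies those cases, assuming straightforward L2 scaling rather than the repository's actual implementation:

```go
package main

import (
	"fmt"
	"math"
)

// normalize L2-normalizes vec; an all-zero vector is returned unchanged,
// matching the zero-norm case the test allows.
func normalize(vec []float32) []float32 {
	var sum float64
	for _, v := range vec {
		sum += float64(v) * float64(v)
	}
	if sum == 0 {
		return vec
	}
	scale := float32(1 / math.Sqrt(sum))
	out := make([]float32, len(vec))
	for i, v := range vec {
		out[i] = v * scale
	}
	return out
}

func main() {
	fmt.Println(normalize([]float32{0, 1, 2, 3})) // unit-length result
	fmt.Println(normalize([]float32{0, 0, 0}))    // stays all zeros
}
```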
@@ -24,7 +24,7 @@ type LlmRequest struct {
    model           *Model
    opts            api.Options
    origNumCtx      int // Track the initial ctx request
-   sessionDuration time.Duration
+   sessionDuration *api.Duration
    successCh       chan *runnerRef
    errCh           chan error
    schedAttempts   uint

@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
}

// context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
    if opts.NumCtx < 4 {
        opts.NumCtx = 4
    }

@@ -133,10 +133,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
            numParallel = 1
            slog.Warn("multimodal models don't support parallel requests yet")
        }
-       // Keep NumCtx and numParallel in sync
-       if numParallel > 1 {
-           pending.opts.NumCtx = pending.origNumCtx * numParallel
-       }

        for {
            var runnerToExpire *runnerRef

@@ -197,9 +193,10 @@ func (s *Scheduler) processPending(ctx context.Context) {
            // simplifying assumption of defaultParallel when in CPU mode
            if numParallel <= 0 {
                numParallel = defaultParallel
-               pending.opts.NumCtx = pending.origNumCtx * numParallel
            }

+           pending.opts.NumCtx = pending.origNumCtx * numParallel
+
            if loadedCount == 0 {
                slog.Debug("cpu mode with first model, loading")
                s.loadFn(pending, ggml, gpus, numParallel)

@@ -389,7 +386,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
        runner.expireTimer.Stop()
        runner.expireTimer = nil
    }
-   runner.sessionDuration = pending.sessionDuration
+   if pending.sessionDuration != nil {
+       runner.sessionDuration = pending.sessionDuration.Duration
+   }
    pending.successCh <- runner
    go func() {
        <-pending.ctx.Done()

@@ -402,6 +401,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
    if numParallel < 1 {
        numParallel = 1
    }
+   sessionDuration := envconfig.KeepAlive
+   if req.sessionDuration != nil {
+       sessionDuration = req.sessionDuration.Duration
+   }
    llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
    if err != nil {
        // some older models are not compatible with newer versions of llama.cpp
@@ -419,7 +422,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
        modelPath:       req.model.ModelPath,
        llama:           llama,
        Options:         &req.opts,
-       sessionDuration: req.sessionDuration,
+       sessionDuration: sessionDuration,
        gpus:            gpus,
        estimatedVRAM:   llama.EstimatedVRAM(),
        estimatedTotal:  llama.EstimatedTotal(),
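The sched.go hunks above change `sessionDuration` from a bare `time.Duration` to a `*api.Duration`, so the scheduler can tell "keep_alive not supplied" (nil, fall back to the `envconfig.KeepAlive` default) apart from an explicit zero. A self-contained sketch of that fallback, with a local `Duration` type and a hypothetical default standing in for the real `api.Duration` and `envconfig.KeepAlive`:

```go
package main

import (
	"fmt"
	"time"
)

// Duration mirrors the api.Duration wrapper only as far as this sketch needs:
// a struct holding a time.Duration, passed by pointer so "unset" can be nil.
type Duration struct {
	Duration time.Duration
}

// resolveKeepAlive sketches the fallback the diff above introduces:
// nil means "not provided, use the server default"; a non-nil zero is honored.
func resolveKeepAlive(req *Duration, defaultKeepAlive time.Duration) time.Duration {
	if req == nil {
		return defaultKeepAlive
	}
	return req.Duration
}

func main() {
	def := 5 * time.Minute // hypothetical stand-in for envconfig.KeepAlive
	fmt.Println(resolveKeepAlive(nil, def))                    // 5m0s (default)
	fmt.Println(resolveKeepAlive(&Duration{Duration: 0}, def)) // 0s (explicit)
}
```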
@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
        opts:            api.DefaultOptions(),
        successCh:       make(chan *runnerRef, 1),
        errCh:           make(chan error, 1),
-       sessionDuration: 2,
+       sessionDuration: &api.Duration{Duration: 2 * time.Second},
    }
    // Fail to load model first
    s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
        ctx:             scenario.ctx,
        model:           model,
        opts:            api.DefaultOptions(),
-       sessionDuration: 5 * time.Millisecond,
+       sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
        successCh:       make(chan *runnerRef, 1),
        errCh:           make(chan error, 1),
    }
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {

    // Same model, same request
    scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
-   scenario1a.req.sessionDuration = 5 * time.Millisecond
+   scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
    scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
    scenario1b.req.model = scenario1a.req.model
    scenario1b.ggml = scenario1a.ggml
-   scenario1b.req.sessionDuration = 0
+   scenario1b.req.sessionDuration = &api.Duration{Duration: 0}

    // simple reload of same model
    scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
    tmpModel := *scenario1a.req.model
    scenario2a.req.model = &tmpModel
    scenario2a.ggml = scenario1a.ggml
-   scenario2a.req.sessionDuration = 5 * time.Millisecond
+   scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}

    // Multiple loaded models
    scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -199,6 +199,8 @@ func TestRequests(t *testing.T) {
        require.Equal(t, resp.llama, scenario1a.srv)
        require.Empty(t, s.pendingReqCh)
        require.Empty(t, scenario1a.req.errCh)
+   case err := <-scenario1a.req.errCh:
+       t.Fatal(err.Error())
    case <-ctx.Done():
        t.Fatal("timeout")
    }
@@ -212,6 +214,8 @@ func TestRequests(t *testing.T) {
        require.Equal(t, resp.llama, scenario1a.srv)
        require.Empty(t, s.pendingReqCh)
        require.Empty(t, scenario1b.req.errCh)
+   case err := <-scenario1b.req.errCh:
+       t.Fatal(err.Error())
    case <-ctx.Done():
        t.Fatal("timeout")
    }
@@ -230,6 +234,8 @@ func TestRequests(t *testing.T) {
        require.Equal(t, resp.llama, scenario2a.srv)
        require.Empty(t, s.pendingReqCh)
        require.Empty(t, scenario2a.req.errCh)
+   case err := <-scenario2a.req.errCh:
+       t.Fatal(err.Error())
    case <-ctx.Done():
        t.Fatal("timeout")
    }
@@ -246,6 +252,8 @@ func TestRequests(t *testing.T) {
        require.Equal(t, resp.llama, scenario3a.srv)
        require.Empty(t, s.pendingReqCh)
        require.Empty(t, scenario3a.req.errCh)
+   case err := <-scenario3a.req.errCh:
+       t.Fatal(err.Error())
    case <-ctx.Done():
        t.Fatal("timeout")
    }
@@ -262,6 +270,8 @@ func TestRequests(t *testing.T) {
        require.Equal(t, resp.llama, scenario3b.srv)
        require.Empty(t, s.pendingReqCh)
        require.Empty(t, scenario3b.req.errCh)
+   case err := <-scenario3b.req.errCh:
+       t.Fatal(err.Error())
    case <-ctx.Done():
        t.Fatal("timeout")
    }
@@ -278,6 +288,8 @@ func TestRequests(t *testing.T) {
        require.Equal(t, resp.llama, scenario3c.srv)
        require.Empty(t, s.pendingReqCh)
        require.Empty(t, scenario3c.req.errCh)
+   case err := <-scenario3c.req.errCh:
+       t.Fatal(err.Error())
    case <-ctx.Done():
        t.Fatal("timeout")
    }
@@ -318,11 +330,11 @@ func TestGetRunner(t *testing.T) {
    defer done()

    scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
-   scenario1a.req.sessionDuration = 0
+   scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
    scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
-   scenario1b.req.sessionDuration = 0
+   scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
    scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
-   scenario1c.req.sessionDuration = 0
+   scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
    envconfig.MaxQueuedRequests = 1
    s := InitScheduler(ctx)
    s.getGpuFn = func() gpu.GpuInfoList {
@@ -402,7 +414,7 @@ func TestPrematureExpired(t *testing.T) {
    case <-ctx.Done():
        t.Fatal("timeout")
    }
-   time.Sleep(scenario1a.req.sessionDuration)
+   time.Sleep(scenario1a.req.sessionDuration.Duration)
    scenario1a.ctxDone()
    time.Sleep(20 * time.Millisecond)
    require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,7 +435,7 @@ func TestUseLoadedRunner(t *testing.T) {
        ctx:             ctx,
        opts:            api.DefaultOptions(),
        successCh:       make(chan *runnerRef, 1),
-       sessionDuration: 2,
+       sessionDuration: &api.Duration{Duration: 2},
    }
    finished := make(chan *LlmRequest)
    llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -614,7 +626,7 @@ func TestAlreadyCanceled(t *testing.T) {
    dctx, done2 := context.WithCancel(ctx)
    done2()
    scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
-   scenario1a.req.sessionDuration = 0
+   scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
    s := InitScheduler(ctx)
    slog.Info("scenario1a")
    s.pendingReqCh <- scenario1a.req
@@ -630,8 +642,8 @@ type mockLlm struct {
    pingResp         error
    waitResp         error
    completionResp   error
-   embeddingResp    []float64
-   embeddingRespErr error
+   embedResp        [][]float32
+   embedRespErr     error
    tokenizeResp     []int
    tokenizeRespErr  error
    detokenizeResp   string
@@ -648,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
    return s.completionResp
}
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
-   return s.embeddingResp, s.embeddingRespErr
+func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
+   return s.embedResp, s.embedRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
    return s.tokenizeResp, s.tokenizeRespErr
|
67
server/testdata/tools/command-r-plus.gotmpl
vendored
Normal file
67
server/testdata/tools/command-r-plus.gotmpl
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
|
||||||
|
{{- if .Tools }}# Safety Preamble
|
||||||
|
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
|
||||||
|
|
||||||
|
# System Preamble
|
||||||
|
## Basic Rules
|
||||||
|
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
|
||||||
|
|
||||||
|
{{ if .System }}# User Preamble
|
||||||
|
{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
Here is a list of tools that you have available to you:
|
||||||
|
{{- range .Tools }}
|
||||||
|
|
||||||
|
```python
|
||||||
|
def {{ .Function.Name }}(
|
||||||
|
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
|
||||||
|
"""{{ .Function.Description }}
|
||||||
|
|
||||||
|
{{- if .Function.Parameters.Properties }}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
{{- range $name, $property := .Function.Parameters.Properties }}
|
||||||
|
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
{{- end }}
|
||||||
|
{{- else if .System }}{{ .System }}
|
||||||
|
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||||
|
{{- end }}
|
||||||
|
{{- range .Messages }}
|
||||||
|
{{- if eq .Role "system" }}
|
||||||
|
{{- continue }}
|
||||||
|
{{- end }}<|START_OF_TURN_TOKEN|>
|
||||||
|
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
|
||||||
|
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
|
||||||
|
{{- if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}
|
||||||
|
Action: ```json
|
||||||
|
[
|
||||||
|
{{- range .ToolCalls }}
|
||||||
|
{
|
||||||
|
"tool_name": "{{ .Function.Name }}",
|
||||||
|
"parameters": {{ json .Function.Arguments }}
|
||||||
|
}
|
||||||
|
{{- end }}
|
||||||
|
]```
|
||||||
|
{{ continue }}
|
||||||
|
{{ end }}
|
||||||
|
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
|
||||||
|
{{ .Content }}</results>
|
||||||
|
{{- end }}<|END_OF_TURN_TOKEN|>
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tool_name": title of the tool in the specification,
|
||||||
|
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
|
||||||
|
}
|
||||||
|
]```
|
||||||
|
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
|
39  server/testdata/tools/command-r-plus.out  vendored  Normal file
@@ -0,0 +1,39 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.

# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.

# User Preamble
You are a knowledgable assistant. You can answer questions and perform tasks.

## Available Tools
Here is a list of tools that you have available to you:

```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather

Args:
format (string): The temperature unit to use. Infer this from the users location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
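These .gotmpl/.out pairs look like Go text/template inputs and their expected renderings, with a `json` helper used for tool arguments (e.g. `{{ json .Function.Arguments }}`). A minimal, self-contained sketch of that kind of rendering setup; the miniature template and data here are invented, and only the `json` FuncMap registration mirrors what the templates above rely on:

```go
package main

import (
	"encoding/json"
	"os"
	"text/template"
)

func main() {
	// Register a "json" function like the one the tool templates call.
	funcs := template.FuncMap{
		"json": func(v any) string {
			b, _ := json.Marshal(v)
			return string(b)
		},
	}

	// Hypothetical miniature template; the real ones live in server/testdata/tools.
	tmpl := template.Must(template.New("tool").Funcs(funcs).Parse(
		`{{ range .ToolCalls }}{"name": "{{ .Name }}", "arguments": {{ json .Arguments }}}{{ end }}`))

	data := map[string]any{
		"ToolCalls": []map[string]any{
			{"Name": "get_current_weather", "Arguments": map[string]string{"location": "Paris, France", "format": "celsius"}},
		},
	}
	_ = tmpl.Execute(os.Stdout, data)
}
```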
31  server/testdata/tools/firefunction.gotmpl  vendored  Normal file
@@ -0,0 +1,31 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can chose to call one or more of the provided functions.

Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls

If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent

Available functions as JSON spec:
{{- if .Tools }}
{{ json .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
17  server/testdata/tools/firefunction.out  vendored  Normal file
@@ -0,0 +1,17 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can chose to call one or more of the provided functions.

Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls

If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent

Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
39  server/testdata/tools/messages.json  vendored  Normal file
@@ -0,0 +1,39 @@
[
  {
    "role": "system",
    "content": "You are a knowledgable assistant. You can answer questions and perform tasks."
  },
  {
    "role": "user",
    "content": "What's the weather like today in Paris?"
  },
  {
    "role": "assistant",
    "tool_calls": [
      {
        "id": "89a1e453-0bce-4de3-a456-c54bed09c520",
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "arguments": {
            "location": "Paris, France",
            "format": "celsius"
          }
        }
      }
    ]
  },
  {
    "role": "tool",
    "tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
    "content": "22"
  },
  {
    "role": "assistant",
    "content": "The current temperature in Paris, France is 22 degrees Celsius."
  },
  {
    "role": "user",
    "content": "What's the weather like today in San Francisco and Toronto?"
  }
]
15  server/testdata/tools/mistral.gotmpl  vendored  Normal file
@@ -0,0 +1,15 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}

{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}
3  server/testdata/tools/mistral.out  vendored  Normal file
@@ -0,0 +1,3 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.

What's the weather like today in San Francisco and Toronto?[/INST]
30  server/testdata/tools/tools.json  vendored  Normal file
@@ -0,0 +1,30 @@
[
  {
    "type": "function",
    "function": {
      "name": "get_current_weather",
      "description": "Get the current weather",
      "parameters": {
        "type": "object",
        "properties": {
          "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA"
          },
          "format": {
            "type": "string",
            "enum": [
              "celsius",
              "fahrenheit"
            ],
            "description": "The temperature unit to use. Infer this from the users location."
          }
        },
        "required": [
          "location",
          "format"
        ]
      }
    }
  }
]
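tools.json uses the OpenAI-style function-tool shape: a `function` object with `name`, `description`, and a JSON-schema-like `parameters` block. A small decoding sketch; the Go types here are hypothetical mirrors of the visible fields, not the repository's own definitions:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical mirrors of the fields visible in tools.json.
type toolFunction struct {
	Name        string          `json:"name"`
	Description string          `json:"description"`
	Parameters  json.RawMessage `json:"parameters"` // object schema: type/properties/required
}

type tool struct {
	Type     string       `json:"type"`
	Function toolFunction `json:"function"`
}

func main() {
	data := []byte(`[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{}}}}]`)

	var tools []tool
	if err := json.Unmarshal(data, &tools); err != nil {
		panic(err)
	}
	fmt.Println(tools[0].Function.Name) // get_current_weather
}
```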
@@ -4,4 +4,5 @@
{{ .Prompt }}
+
{{ end }}### Response:
{{ .Response }}

@@ -3,4 +3,4 @@
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
@@ -2,4 +2,5 @@

{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
-{{ end }}Assistant: <|begin_of_text|>{{ .Response }}
+
+{{ end }}Assistant: {{ .Response }}

10  template/codellama-70b-instruct.gotmpl  Normal file
@@ -0,0 +1,10 @@
{{ if .System }}Source: system

{{ .System }} <step> {{ end }}Source: user

{{ .Prompt }} <step> Source: assistant
{{- if not .Response }}
Destination: user
{{- end }}

{{ .Response }} <step>
5  template/falcon-instruct.gotmpl  Normal file
@@ -0,0 +1,5 @@
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User:
{{ .Prompt }}
{{ end }}Falcon:
{{ .Response }}
5  template/gemma-instruct.gotmpl  Normal file
@@ -0,0 +1,5 @@
<start_of_turn>user
{{ if .System }}{{ .System }}
{{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>
@@ -1,9 +1,9 @@
-{{ if .System }}
-System:
+{{ if .System }}System:
{{ .System }}
+
{{ end }}{{ if .Prompt }}Question:
{{ .Prompt }}

{{ end }}Answer:
{{ .Response }}

6  template/llama2-chat.gotmpl  Normal file
@@ -0,0 +1,6 @@
[INST] <<SYS>>
{{- if .System }}
{{ .System }}
{{ end }}<</SYS>>

{{ .Prompt }} [/INST] {{ .Response }}</s><s>
@@ -4,4 +4,5 @@
{{ .Prompt }}
+
{{ end }}@@ Response
{{ .Response }}

3  template/mistral-instruct.gotmpl  Normal file
@@ -0,0 +1,3 @@
[INST] {{ if .System }}{{ .System }}

{{ end }}{{ .Prompt }}[/INST] {{ .Response }}</s>
1  template/openchat.gotmpl  Normal file
@@ -0,0 +1 @@
{{ if .System }}GPT4 Correct System: {{ .System }}<|end_of_turn|>{{ end }}GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>
@@ -3,4 +3,4 @@
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>
@@ -5,4 +5,5 @@
{{ .Prompt }}
+
{{ end }}### Assistant:
-{{ .Response }}
+{{ .Response }}</s>

@@ -3,7 +3,6 @@
{{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }}

-
{{ end }}### Response
{{ .Response }}<|endoftext|>

Some files were not shown because too many files have changed in this diff.