Compare commits

...

10 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| ParthSareen | b4de2e9189 | change name to context_length | 2025-02-07 11:50:38 -08:00 |
| ParthSareen | 61a5254115 | context_window and addressing comments | 2025-02-05 11:26:55 -08:00 |
| ParthSareen | 53d2cf37d2 | update docs | 2025-02-04 15:17:16 -08:00 |
| ParthSareen | 75f88e7aac | Update docs | 2025-02-04 10:47:32 -08:00 |
| ParthSareen | 4982089c84 | Fix formatting | 2025-01-30 13:53:24 -08:00 |
| Parth Sareen | 8c231b0826 | Update openai/openai.go (Co-authored-by: Michael Yang <mxyng@pm.me>) | 2025-01-30 13:50:25 -08:00 |
| ParthSareen | 16abd181a9 | remove context shifting with max tokens and update docs | 2025-01-30 13:48:24 -08:00 |
| ParthSareen | 5c2f35d846 | Add tests | 2025-01-30 13:16:15 -08:00 |
| ParthSareen | 6de3227841 | Cleanup api | 2025-01-30 13:15:57 -08:00 |
| ParthSareen | 35e97db03b | set num_ctx through extra body | 2025-01-29 13:13:11 -08:00 |
3 changed files with 111 additions and 43 deletions

View File: docs/openai.md

@@ -204,6 +204,45 @@ curl http://localhost:11434/v1/embeddings \
}'
```
## Extra arguments
### Setting context length
- The `context_length` parameter can be used to set the context length for the model
#### OpenAI Python library
- The OpenAI Python library does not support setting the context length directly; however, it can be set for Ollama through the `extra_body` parameter:
```py
completion = client.chat.completions.create(
model="llama3.1:8b",
messages=[{"role": "user", "content": "Say this is a test"}],
extra_body={"context_length": 4096},
)
```
#### OpenAI JavaScript library
- The OpenAI JavaScript library does not support setting the context length directly; however, it can be set for Ollama by passing `context_length` as an undocumented parameter, with a `@ts-expect-error` directive to suppress the type error. [See the documentation on making custom/undocumented requests](https://github.com/openai/openai-node?tab=readme-ov-file#making-customundocumented-requests)
```ts
const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama3.2',
// @ts-expect-error context_length is an additional parameter
context_length: 4096,
})
```
#### `curl`
```shell
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3.2",
"messages": [{"role": "user", "content": "Say this is a test"}],
"context_length": 4096
}'
```
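Under the hood, `context_length` is translated to Ollama's `num_ctx` option (see `fromChatRequest` in this change). For reference, a roughly equivalent request against Ollama's native `/api/chat` endpoint, sketched here assuming the same local server, looks like:
```shell
# Sketch: the same request via Ollama's native chat endpoint,
# where the context size is passed as the num_ctx option.
curl http://localhost:11434/api/chat \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama3.2",
    "messages": [{"role": "user", "content": "Say this is a test"}],
    "options": {"num_ctx": 4096}
  }'
```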
## Endpoints
### `/v1/chat/completions`
@@ -213,6 +252,7 @@ curl http://localhost:11434/v1/embeddings \
- [x] Chat completions
- [x] Streaming
- [x] JSON mode
- [x] Structured outputs
- [x] Reproducible outputs
- [x] Vision
- [x] Tools
@@ -339,27 +379,3 @@ curl http://localhost:11434/v1/chat/completions \
}'
```
### Setting the context size
The OpenAI API does not have a way of setting the context size for a model. If you need to change the context size, create a `Modelfile` which looks like:
```modelfile
FROM <some model>
PARAMETER num_ctx <context size>
```
Use the `ollama create mymodel` command to create a new model with the updated context size. Call the API with the updated model name:
```shell
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "mymodel",
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```

View File: openai/openai.go

@@ -80,10 +80,12 @@ type StreamOptions struct {
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
StreamOptions *StreamOptions `json:"stream_options"`
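// MaxCompletionTokens limits the number of tokens generated (mapped to Ollama's num_predict option)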
MaxCompletionTokens *int `json:"max_completion_tokens"`
// Deprecated: Use [ChatCompletionRequest.MaxCompletionTokens]
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
@@ -93,6 +95,7 @@ type ChatCompletionRequest struct {
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
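// ContextLength sets the context length for the model (mapped to Ollama's num_ctx option)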
ContextLength *int `json:"context_length"`
}
type ChatCompletion struct {
@@ -475,8 +478,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
options["stop"] = stops
}
if r.ContextLength != nil {
options["num_ctx"] = *r.ContextLength
}
// Deprecated: MaxTokens is deprecated, use MaxCompletionTokens instead
if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens
r.MaxCompletionTokens = r.MaxTokens
}
if r.MaxCompletionTokens != nil {
options["num_predict"] = *r.MaxCompletionTokens
}
if r.Temperature != nil {
@@ -962,6 +974,7 @@ func ChatMiddleware() gin.HandlerFunc {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
slog.Info("num_ctx", "num_ctx", chatReq.Options["num_ctx"])
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))

View File: openai/openai_test.go

@@ -7,7 +7,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
@@ -315,6 +314,42 @@ func TestChatMiddleware(t *testing.T) {
Stream: &True,
},
},
{
name: "chat handler with context_length",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"context_length": 4096
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with max_completion_tokens",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 2
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_predict": 2.0, // float because JSON doesn't distinguish between float and int
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler error forwarding",
body: `{
@@ -359,7 +394,7 @@ func TestChatMiddleware(t *testing.T) {
return
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
@@ -493,12 +528,14 @@ func TestCompletionsMiddleware(t *testing.T) {
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
if capturedRequest != nil {
if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match (-want +got):\n%s", diff)
}
capturedRequest = nil
@@ -577,12 +614,14 @@ func TestEmbeddingsMiddleware(t *testing.T) {
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
if capturedRequest != nil {
if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match (-want +got):\n%s", diff)
}
capturedRequest = nil
@@ -656,8 +695,8 @@ func TestListMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match (-want +got):\n%s", diff)
}
}
}
@@ -722,8 +761,8 @@ func TestRetrieveMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match (-want +got):\n%s", diff)
}
}
}