From 16abd181a91da9df62501503fd3b61449f849904 Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Thu, 30 Jan 2025 13:48:24 -0800
Subject: [PATCH] remove context shifting with max tokens and update docs

---
 docs/openai.md        | 30 ++++++++++++++++++++++++++++++
 openai/openai.go      | 15 ++++-----------
 openai/openai_test.go | 23 ++---------------------
 3 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/docs/openai.md b/docs/openai.md
index b0f9b353c..0bed2db6d 100644
--- a/docs/openai.md
+++ b/docs/openai.md
@@ -94,6 +94,20 @@ except Exception as e:
     print(f"Error: {e}")
 ```
 
+#### Experimental
+
+- The `num_ctx` parameter can be used to set the context window size for the model
+- The OpenAI Python SDK does not support setting the context window size; however, it can be set for Ollama through the `extra_body` parameter
+
+- The recommended way to control this is through the [Ollama Python SDK](https://github.com/ollama/ollama-python) with the `options` parameter
+```py
+completion = client.chat.completions.create(
+    model="llama3.1:8b",
+    messages=[{"role": "user", "content": "Say this is a test"}],
+    extra_body={"num_ctx": 4096},
+)
+```
+
 ### OpenAI JavaScript library
 
 ```javascript
@@ -142,6 +156,21 @@ const embedding = await openai.embeddings.create({
 })
 ```
 
+#### Experimental
+
+- The `num_ctx` parameter can be used to set the context window size for the model
+- The OpenAI JS SDK does not support setting the context window size; however, it can be set for Ollama by passing `num_ctx` directly as an [undocumented parameter](https://github.com/openai/openai-node?tab=readme-ov-file#making-customundocumented-requests) with a `@ts-expect-error` comment
+
+- The recommended way to control this is through the [Ollama JS SDK](https://github.com/ollama/ollama-js) with the `options` parameter
+```js
+const chatCompletion = await openai.chat.completions.create({
+  messages: [{ role: 'user', content: 'Say this is a test' }],
+  model: 'llama3.2',
+  // @ts-expect-error num_ctx is not officially supported
+  num_ctx: 4096,
+})
+```
+
 ### `curl`
 
 ``` shell
@@ -213,6 +242,7 @@ curl http://localhost:11434/v1/embeddings \
 - [x] Chat completions
 - [x] Streaming
 - [x] JSON mode
+- [x] Structured outputs
 - [x] Reproducible outputs
 - [x] Vision
 - [x] Tools
diff --git a/openai/openai.go b/openai/openai.go
index 03cd74879..dc9f1e795 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -477,24 +477,17 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["stop"] = stops
 	}
 
+	if r.NumCtx != nil {
+		options["num_ctx"] = *r.NumCtx
+	}
+
 	// Deprecated: MaxTokens is deprecated, use MaxCompletionTokens instead
 	if r.MaxTokens != nil {
 		r.MaxCompletionTokens = r.MaxTokens
 	}
 
-	if r.NumCtx != nil {
-		options["num_ctx"] = *r.NumCtx
-	}
-
-	DEFAULT_NUM_CTX := 2048
-	// set num_ctx to max_completion_tokens if it's greater than num_ctx
 	if r.MaxCompletionTokens != nil {
 		options["num_predict"] = *r.MaxCompletionTokens
-		if r.NumCtx != nil && *r.MaxCompletionTokens > *r.NumCtx {
-			options["num_ctx"] = *r.MaxCompletionTokens
-		} else if *r.MaxCompletionTokens > DEFAULT_NUM_CTX {
-			options["num_ctx"] = *r.MaxCompletionTokens
-		}
 	}
 
 	if r.Temperature != nil {
diff --git a/openai/openai_test.go b/openai/openai_test.go
index 32f680505..00be4e426 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -81,7 +81,7 @@ func TestChatMiddleware(t *testing.T) {
 				{"role": "user", "content": "Hello"}
 			],
 			"stream": true,
-			"max_completion_tokens": 999,
+			"max_tokens": 999,
 			"seed": 123,
 			"stop": ["\n", "stop"],
 			"temperature": 3.0,
@@ -333,7 +333,7 @@ func TestChatMiddleware(t *testing.T) {
 			},
 		},
 		{
-			name: "chat handler with max_completion_tokens < num_ctx",
+			name: "chat handler with max_completion_tokens",
 			body: `{
 				"model": "test-model",
 				"messages": [{"role": "user", "content": "Hello"}],
@@ -350,25 +350,6 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &False,
 			},
 		},
-		{
-			name: "chat handler with max_completion_tokens > num_ctx",
-			body: `{
-				"model": "test-model",
-				"messages": [{"role": "user", "content": "Hello"}],
-				"max_completion_tokens": 4096
-			}`,
-			req: api.ChatRequest{
-				Model:    "test-model",
-				Messages: []api.Message{{Role: "user", Content: "Hello"}},
-				Options: map[string]any{
-					"num_predict": 4096.0, // float because JSON doesn't distinguish between float and int
-					"num_ctx":     4096.0,
-					"temperature": 1.0,
-					"top_p":       1.0,
-				},
-				Stream: &False,
-			},
-		},
 		{
 			name: "chat handler error forwarding",
 			body: `{