Compare commits

...

1 Commit

Author SHA1 Message Date
Devon Rifkin
bc8abf7917 WIP thinking API support
- Allows specifying whether thinking mode should be on or not
- Templates get passed a new option so, e.g., qwen3's template can put
  `/think` or `/no_think` in the system prompt depending on the value of
  the setting
- Add parsing for thinking blocks in both streaming and non-streaming modes
- Update the CLI to make use of these changes
2025-05-12 17:23:41 -07:00
18 changed files with 1030 additions and 25 deletions
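
For context while reviewing, here is a minimal sketch of how a client could exercise the new field once this lands (the model name and prompt are illustrative; the `Think` pointer and `Message.Thinking` field are the additions shown in the diffs below):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	think := true // a *bool: explicit false means "turn thinking off", nil means "unset"
	req := &api.ChatRequest{
		Model:    "qwen3", // illustrative model name
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
		Think:    &think,
	}
	// thinking tokens and regular content arrive in separate fields of each chunk
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Thinking)
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}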

View File

@@ -83,6 +83,12 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding. Needs to be a pointer so we can distinguish between false
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *bool `json:"think,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -108,6 +114,10 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding
Think *bool `json:"think,omitempty"`
}
type Tools []Tool
@@ -126,8 +136,11 @@ func (t Tool) String() string {
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
@@ -275,6 +288,8 @@ type Options struct {
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"`
Think bool `json:"think,omitempty"`
}
// Runner options which must be set when the model is loaded into memory

View File

@@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
})
}
}
func TestThinking_UnmarshalJSON(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
input string
expectedThinking *bool
expectedError bool
}{
{
name: "true",
input: `{ "think": true }`,
expectedThinking: &trueVal,
},
{
name: "false",
input: `{ "think": false }`,
expectedThinking: &falseVal,
},
{
name: "unset",
input: `{ }`,
expectedThinking: nil,
},
{
name: "invalid",
input: `{ "think": "true" }`,
expectedThinking: nil,
expectedError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req GenerateRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expectedThinking, req.Think)
}
})
}
}

View File

@@ -38,12 +38,32 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// warnMissingThinking emits a warning if the model does not advertise thinking
// support and opts.Thinking is set. Failures to query the capability are
// ignored so this does not impact regular usage.
func warnMissingThinking(ctx context.Context, client *api.Client, name string) {
if name == "" {
return
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -240,9 +260,18 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
}
think := opts.Think
if think == nil {
falseVal := false
think = &falseVal
}
req := &api.GenerateRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
Think: opts.Think,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -277,6 +306,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.Format = format
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag.Changed {
think, err := cmd.Flags().GetBool("think")
if err != nil {
return err
}
opts.Think = &think
} else {
opts.Think = nil
}
keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil {
return err
@@ -361,6 +401,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err
}
warnMissingThinking(cmd.Context(), client, opts.Model)
for _, msg := range info.Messages {
switch msg.Role {
@@ -876,6 +917,7 @@ type runOptions struct {
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Think *bool
}
type displayResponseState struct {
@@ -958,6 +1000,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
fn := func(response api.ChatResponse) error {
p.StopAndClear()
@@ -965,7 +1009,23 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest = response
role = response.Message.Role
if response.Message.Thinking != "" {
if !thinkTagOpened {
fmt.Print(readline.ColorGrey + readline.ColorBold + "<think>" + readline.ColorDefault + readline.ColorGrey)
thinkTagOpened = true
}
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content
if !thinkTagClosed && thinkTagOpened && content != "" {
fmt.Print(readline.ColorGrey + readline.ColorBold + "</think>" + readline.ColorDefault)
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state)
@@ -982,6 +1042,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
if opts.Think != nil {
warnMissingThinking(cmd.Context(), client, opts.Model)
}
if opts.KeepAlive != nil {
@@ -1075,6 +1140,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
Think: opts.Think,
}
if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1290,6 +1356,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().Bool("think", false, "Turn on thinking mode for supported models")
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@ -260,6 +262,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
fmt.Println("Set 'quiet' mode.") fmt.Println("Set 'quiet' mode.")
case "think":
think := true
opts.Think = &think
if client, err := api.ClientFromEnvironment(); err == nil {
warnMissingThinking(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'think' mode.")
case "nothink":
think := false
opts.Think = &think
fmt.Println("Set 'nothink' mode.")
case "format": case "format":
if len(args) < 3 || args[2] != "json" { if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")

30
cmd/templatefmt/main.go Normal file
View File

@@ -0,0 +1,30 @@
package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"github.com/ollama/ollama/template"
)
func main() {
flag.Parse()
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "usage: %s <template.gotmpl>\n", os.Args[0])
os.Exit(2)
}
path := flag.Arg(0)
data, err := ioutil.ReadFile(path)
if err != nil {
log.Fatal(err)
}
out, err := template.Format(string(data))
if err != nil {
log.Fatal(err)
}
fmt.Print(out)
}
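
A hypothetical invocation of this helper, useful for eyeballing how a model's template branches on the new thinking values (the template path is made up):

go run ./cmd/templatefmt path/to/model.gotmpl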

64
cmd/warn_thinking_test.go Normal file
View File

@@ -0,0 +1,64 @@
package cmd
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
warnMissingThinking(context.Background(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}

View File

@@ -61,6 +61,8 @@ const (
ColorGrey = Esc + "[38;5;245m"
ColorDefault = Esc + "[0m"
ColorBold = Esc + "[1m"
StartBracketedPaste = Esc + "[?2004h"
EndBracketedPaste = Esc + "[?2004l"
)

View File

@@ -37,6 +37,7 @@ var (
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errInsecureProtocol = errors.New("insecure protocol http")
)
@@ -106,6 +107,12 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for thinking capability
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
@@ -122,6 +129,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
}
for _, cap := range want {

View File

@@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
isMllama := checkMllamaModelFamily(m)
@@ -56,8 +56,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
thinkVal := false
if think != nil {
thinkVal = *think
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}
@@ -142,7 +146,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
thinkVal := false
if think != nil {
thinkVal = *think
}
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err return "", nil, err
} }

View File

@@ -318,7 +318,8 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
think := false
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {

View File

@@ -18,7 +18,6 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"slices" "slices"
"strings" "strings"
"syscall" "syscall"
@ -181,6 +180,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" { if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert) caps = append(caps, model.CapabilityInsert)
} }
if req.Think != nil {
// note that the capability is still required even if `Thinking` is false
// because turning off thinking requires the model to support it (e.g.,
// older qwen3 templates don't know how to turn off thinking)
caps = append(caps, model.CapabilityThinking)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
@@ -1475,6 +1480,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
if req.Think != nil {
caps = append(caps, model.CapabilityThinking)
}
name := model.ParseName(req.Model)
if !name.IsValid() {
@@ -1515,7 +1523,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1524,6 +1532,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
slog.Debug("chat request", "images", len(images), "prompt", prompt)
var thinkingState *thinkingParser
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinkingParser{
openingTag: openingTag,
closingTag: closingTag,
}
}
ch := make(chan any)
go func() {
defer close(ch)
@@ -1548,6 +1565,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
},
}
if thinkingState != nil {
if openingTag == "" || closingTag == "" {
// TODO(drifkin): put warning here
} else {
thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
@@ -1565,7 +1596,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
sb.WriteString(res.Message.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
@@ -1613,9 +1644,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
resp.Message.Content = sb.String()
if req.Think != nil && *req.Think {
resp.Message.Thinking, resp.Message.Content = extractThinking(resp.Message.Content)
}
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
if toolCalls, ok := m.parseToolCalls(resp.Message.Content); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
@@ -1643,7 +1677,16 @@ func handleScheduleError(c *gin.Context, name string, err error) {
}
}
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
// returns (thinkingContent, content)
func extractThinking(text string) (string, string) {
thinking := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
thinkingContent, content := thinking.addContent(text)
return thinkingContent, content
}
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
@@ -1656,7 +1699,9 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
thinkingContent, content := extractThinking(msg.Content)
msg.Content = content
msg.Thinking = thinkingContent
}
}
}

View File

@@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
}
})
t.Run("missing thinking capability", func(t *testing.T) {
think := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Think: &think,
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) { t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{}) w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest { if w.Code != http.StatusBadRequest {

256
server/thinking.go Normal file
View File

@@ -0,0 +1,256 @@
package server
import (
"strings"
"text/template"
"text/template/parse"
"unicode"
)
type thinkingParseState int
const (
thinkingParseState_LookingForOpening thinkingParseState = iota
thinkingParseState_Thinking
thinkingParseState_ThinkingDone
)
func (s thinkingParseState) String() string {
switch s {
case thinkingParseState_LookingForOpening:
return "LookingForOpening"
case thinkingParseState_Thinking:
return "Thinking"
case thinkingParseState_ThinkingDone:
return "ThinkingDone"
default:
return "Unknown"
}
}
type thinkingParser struct {
state thinkingParseState
openingTag string
closingTag string
acc strings.Builder
}
// returns the thinking content and the normal content that should be
// immediately sent to the user. It will internally buffer if it needs to see
// more content to disambiguate
func (s *thinkingParser) addContent(content string) (string, string) {
s.acc.WriteString(content)
var thinkingAcc, remainingAcc strings.Builder
var thinking, remaining string
keepLooping := true
// we loop because we might pass through multiple parsing states in a single
// call to addContent, and we want to make sure callers don't have to wait for
// data that's already unambiguous
for keepLooping {
thinking, remaining, keepLooping = eat(s)
thinkingAcc.WriteString(thinking)
remainingAcc.WriteString(remaining)
}
return thinkingAcc.String(), remainingAcc.String()
}
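// An illustrative call sequence (inputs are made up; the behavior follows the
// state machine below and matches the streaming tests later in this diff):
//
//   p := &thinkingParser{openingTag: "<think>", closingTag: "</think>"}
//   thinking, content := p.addContent("<think>reasoning</think>answer")
//   // thinking == "reasoning", content == "answer"
//   thinking, content = p.addContent(" more answer")
//   // thinking == "", content == " more answer" (the parser is already past the closing tag)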
// the additional bool return is true iff we should continue eating
func eat(s *thinkingParser) (string, string, bool) {
switch s.state {
case thinkingParseState_LookingForOpening:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, s.openingTag) {
after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
// after might contain more than just thinking tokens, so we continue
// parsing instead of returning it as thinking tokens here
s.acc.Reset()
s.acc.WriteString(after)
s.state = thinkingParseState_Thinking
return "", "", true
} else if strings.HasPrefix(s.openingTag, trimmed) {
// partial opening seen, so let's keep accumulating
return "", "", false
} else if trimmed == "" {
// saw whitespace only, so let's keep accumulating
return "", "", false
} else {
// didn't see an opening tag, but we have content, so thinking was skipped
s.state = thinkingParseState_ThinkingDone
// note that we use the original content, not the trimmed one because we
// don't want to eat any whitespace in the real content if there were no
// thinking tags
return "", s.acc.String(), false
}
case thinkingParseState_Thinking:
acc := s.acc.String()
if strings.Contains(acc, s.closingTag) {
split := strings.Split(acc, s.closingTag)
thinking := split[0]
remaining := strings.Join(split[1:], s.closingTag)
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
s.acc.Reset()
s.state = thinkingParseState_ThinkingDone
return thinking, remaining, false
} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
thinking := acc[:len(acc)-overlapLen]
remaining := acc[len(acc)-overlapLen:]
s.acc.Reset()
// keep track of the candidate closing tag. We have to buffer it until it
// becomes disambiguated
s.acc.WriteString(remaining)
return thinking, "", false
} else {
// purely just thinking tokens, so we can return them
s.acc.Reset()
return acc, "", false
}
case thinkingParseState_ThinkingDone:
acc := s.acc.String()
s.acc.Reset()
return "", acc, false
default:
panic("unknown state")
}
}
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
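// This is what lets the streaming path buffer a partially received closing tag.
// A couple of illustrative values (not part of the change):
//
//   overlap("abc</thi", "</think>") == 5 // the suffix "</thi" is a prefix of the closing tag
//   overlap("abcd", "</think>") == 0     // no suffix of "abcd" begins the closing tag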
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
if n == nil {
return
}
shouldContinue := enterFn(n)
if !shouldContinue {
return
}
switch x := n.(type) {
case *parse.ListNode:
for _, c := range x.Nodes {
templateVisit(c, enterFn, exitFn)
}
case *parse.BranchNode:
if x.Pipe != nil {
templateVisit(x.Pipe, enterFn, exitFn)
}
if x.List != nil {
templateVisit(x.List, enterFn, exitFn)
}
if x.ElseList != nil {
templateVisit(x.ElseList, enterFn, exitFn)
}
case *parse.ActionNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.WithNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.RangeNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.IfNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.TemplateNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.PipeNode:
for _, c := range x.Cmds {
templateVisit(c, enterFn, exitFn)
}
case *parse.CommandNode:
for _, a := range x.Args {
templateVisit(a, enterFn, exitFn)
}
// text, field, number, etc. are leaves nothing to recurse into
}
if exitFn != nil {
exitFn(n)
}
}
// We use a heuristic to infer the tags that surround thinking traces:
// We look for a range node that iterates over "Messages" and then look for a
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
// ListNode and take the first and last TextNodes as the opening and closing
// tags.
func inferThinkingTags(t *template.Template) (string, string) {
ancestors := []parse.Node{}
openingTag := ""
closingTag := ""
enterFn := func(n parse.Node) bool {
ancestors = append(ancestors, n)
switch x := n.(type) {
case *parse.FieldNode:
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
var mostRecentRange *parse.RangeNode
for i := len(ancestors) - 1; i >= 0; i-- {
if r, ok := ancestors[i].(*parse.RangeNode); ok {
mostRecentRange = r
break
}
}
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
return true
}
// TODO(drifkin): to be more robust, check that it's in the action
// part, not the `if`'s pipeline part. We do match on the nearest list
// that starts and ends with text nodes, which makes this not strictly
// necessary for our heuristic
// go up to the nearest ancestor that is a *parse.ListNode
for i := len(ancestors) - 1; i >= 0; i-- {
if l, ok := ancestors[i].(*parse.ListNode); ok {
firstNode := l.Nodes[0]
if t, ok := firstNode.(*parse.TextNode); ok {
openingTag = strings.TrimSpace(t.String())
}
lastNode := l.Nodes[len(l.Nodes)-1]
if t, ok := lastNode.(*parse.TextNode); ok {
closingTag = strings.TrimSpace(t.String())
}
break
}
}
}
}
return true
}
exitFn := func(n parse.Node) {
ancestors = ancestors[:len(ancestors)-1]
}
templateVisit(t.Root, enterFn, exitFn)
return openingTag, closingTag
}
// checks to see if the given field name is present in the pipeline of the given range node
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
found := false
enterFn := func(n parse.Node) bool {
switch x := n.(type) {
case *parse.FieldNode:
if x.Ident[0] == field {
found = true
}
}
return true
}
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
return found
}

286
server/thinking_test.go Normal file
View File

@@ -0,0 +1,286 @@
package server
import (
"testing"
"text/template"
)
func TestExtractThinking(t *testing.T) {
tests := []struct {
in, wantContent, wantThink string
}{
{
in: "<think> internal </think> world",
wantThink: "internal ",
wantContent: "world",
},
{
in: "<think>a</think><think>b</think>c",
wantThink: "a",
wantContent: "<think>b</think>c",
},
{
in: "no think",
wantThink: "",
wantContent: "no think",
},
}
for i, tt := range tests {
gotThinking, gotContent := extractThinking(tt.in)
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
}
}
}
func TestThinkingStreaming(t *testing.T) {
type step struct {
input string
wantThinking string
wantContent string
wantStateAfter thinkingParseState
}
cases := []struct {
desc string
skip bool
steps []step
}{
{
desc: "content without a thinking tag",
steps: []step{
{
input: " abc",
wantThinking: "",
wantContent: " abc",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
{
desc: "content before a thinking tag nerfs the thinking tag",
steps: []step{
{
input: " abc <think>def</think> ghi",
wantThinking: "",
wantContent: " abc <think>def</think> ghi",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
{
desc: "building up a thinking tag partially",
// skip: true,
steps: []step{
{
input: " <th",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingParseState_LookingForOpening,
},
{
input: "in",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingParseState_LookingForOpening,
},
{
input: "k>a",
wantThinking: "a",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
},
},
{
desc: "partial closing tag",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "ink>def",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
{
desc: "partial closing tag fakeout",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "ing>def",
wantThinking: "</thing>def",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "ghi</thi",
wantThinking: "ghi",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "nk>jkl",
wantThinking: "",
wantContent: "jkl",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
}
for _, c := range cases {
parser := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
if c.skip {
continue
}
for i, step := range c.steps {
thinking, content := parser.addContent(step.input)
if content != step.wantContent || thinking != step.wantThinking {
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
}
if parser.state != step.wantStateAfter {
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state.String(), step.wantStateAfter.String())
}
}
}
}
func TestInferThinkingTags(t *testing.T) {
cases := []struct {
desc string
tmplString string
wantOpeningTag string
wantClosingTag string
}{
{
desc: "basic",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
{
desc: "doubly nested range",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- range $j, $_ := .NotMessages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
{{ end }}
`,
wantOpeningTag: "",
wantClosingTag: "",
},
{
desc: "whitespace is trimmed",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
Some text before {{ .Thinking }} Some text after
{{ end }}
{{ end }}
`,
wantOpeningTag: "Some text before",
wantClosingTag: "Some text after",
},
{
desc: "qwen3",
tmplString: `
{{- if or .System .Tools .Thinking }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}
{{- if .Thinking }}
/think
{{- else }}
/no_think
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
}
for _, c := range cases {
tmpl := template.Must(template.New("test").Parse(c.tmplString))
openingTag, closingTag := inferThinkingTags(tmpl)
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
}
}
}

101
template/pretty.go Normal file
View File

@@ -0,0 +1,101 @@
package template
import (
"strings"
texttmpl "text/template"
"text/template/parse"
)
// Format returns a human-readable representation of the template.
// The formatting indents nested sections such as if/else blocks.
func Format(src string) (string, error) {
tmpl, err := texttmpl.New("pretty").Parse(src)
if err != nil {
return "", err
}
var sb strings.Builder
printNodes(tmpl.Tree.Root, 0, &sb)
return sb.String(), nil
}
func indent(sb *strings.Builder, level int) {
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
}
func printNodes(list *parse.ListNode, level int, sb *strings.Builder) {
if list == nil {
return
}
for _, n := range list.Nodes {
printNode(n, level, sb)
}
}
func printNode(n parse.Node, level int, sb *strings.Builder) {
switch n := n.(type) {
case *parse.TextNode:
text := strings.TrimSpace(string(n.Text))
if text == "" {
return
}
indent(sb, level)
sb.WriteString(text)
sb.WriteByte('\n')
case *parse.ActionNode:
indent(sb, level)
// sb.WriteString("ACTION {{ ")
sb.WriteString(n.String())
// sb.WriteString(" }}\n")
sb.WriteByte('\n')
case *parse.IfNode:
indent(sb, level)
sb.WriteString("{{ if ")
sb.WriteString(n.Pipe.String())
sb.WriteString(" }}\n")
printNodes(n.List, level+1, sb)
if n.ElseList != nil {
indent(sb, level)
sb.WriteString("{{ else }}\n")
printNodes(n.ElseList, level+1, sb)
}
indent(sb, level)
sb.WriteString("{{ end }}\n")
case *parse.RangeNode:
indent(sb, level)
sb.WriteString("{{ range ")
sb.WriteString(n.Pipe.String())
sb.WriteString(" }}\n")
printNodes(n.List, level+1, sb)
if n.ElseList != nil {
indent(sb, level)
sb.WriteString("{{ else }}\n")
printNodes(n.ElseList, level+1, sb)
}
indent(sb, level)
sb.WriteString("{{ end }}\n")
case *parse.WithNode:
indent(sb, level)
sb.WriteString("{{ with ")
sb.WriteString(n.Pipe.String())
sb.WriteString(" }}\n")
printNodes(n.List, level+1, sb)
if n.ElseList != nil {
indent(sb, level)
sb.WriteString("{{ else }}\n")
printNodes(n.ElseList, level+1, sb)
}
indent(sb, level)
sb.WriteString("{{ end }}\n")
case *parse.TemplateNode:
indent(sb, level)
sb.WriteString("{{ template ")
sb.WriteString(n.Name)
sb.WriteString(" }}\n")
default:
indent(sb, level)
sb.WriteString(n.String())
sb.WriteByte('\n')
}
}

30
template/pretty_test.go Normal file
View File

@@ -0,0 +1,30 @@
package template
import (
"strings"
"testing"
)
func TestFormatIndentation(t *testing.T) {
tmpl := "{{ if .Cond }}A{{ else }}B{{ end }}"
out, err := Format(tmpl)
if err != nil {
t.Fatal(err)
}
expectedLines := []string{
"{{ if .Cond }}",
" A",
"{{ else }}",
" B",
"{{ end }}",
}
got := strings.Split(strings.TrimSpace(out), "\n")
if len(got) != len(expectedLines) {
t.Fatalf("expected %d lines, got %d: %q", len(expectedLines), len(got), out)
}
for i, line := range expectedLines {
if strings.TrimSpace(got[i]) != strings.TrimSpace(line) {
t.Errorf("line %d = %q, want %q", i, got[i], line)
}
}
}

View File

@@ -167,6 +167,10 @@ type Values struct {
api.Tools
Prompt string
Suffix string
Think bool
// whether or not the user explicitly set the thinking flag (vs. it being
// implicitly false). Templates can't see whether `Think` is nil
IsThinkSet bool
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
@@ -222,16 +226,20 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
})
}
@@ -241,9 +249,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
for _, m := range messages {
execute := func() error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
}); err != nil {
return err
}
@@ -286,9 +296,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
}); err != nil {
return err
}

View File

@@ -8,6 +8,7 @@ const (
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
)
func (c Capability) String() string {