WIP thinking API support

- Allows specifying whether thinking mode should be on or not
- Templates get passed a new option so, e.g., qwen3's template can put
  `/think` or `/no_think` in the system prompt depending on the value of
  the setting
- Add parsing for thinking blocks in both streaming/non-streaming mode
- Update the CLI to make use of these changes
This commit is contained in:
Devon Rifkin 2025-05-12 17:23:41 -07:00
parent a7835c6716
commit bc8abf7917
18 changed files with 1030 additions and 25 deletions

View File

@ -83,6 +83,12 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding. Needs to be a pointer so we can distinguish between false
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *bool `json:"think,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@ -108,6 +114,10 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding
Think *bool `json:"think,omitempty"`
}
type Tools []Tool
@ -126,8 +136,11 @@ func (t Tool) String() string {
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
@ -275,6 +288,8 @@ type Options struct {
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"`
Think bool `json:"think,omitempty"`
}
// Runner options which must be set when the model is loaded into memory

View File

@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
})
}
}
func TestThinking_UnmarshalJSON(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
input string
expectedThinking *bool
expectedError bool
}{
{
name: "true",
input: `{ "think": true }`,
expectedThinking: &trueVal,
},
{
name: "false",
input: `{ "think": false }`,
expectedThinking: &falseVal,
},
{
name: "unset",
input: `{ }`,
expectedThinking: nil,
},
{
name: "invalid",
input: `{ "think": "true" }`,
expectedThinking: nil,
expectedError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req GenerateRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expectedThinking, req.Think)
}
})
}
}

View File

@ -38,12 +38,32 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
// warnMissingThinking emits a warning if the model does not advertise thinking
// support and opts.Thinking is set. Failures to query the capability are
// ignored so this does not impact regular usage.
func warnMissingThinking(ctx context.Context, client *api.Client, name string) {
if name == "" {
return
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
@ -240,9 +260,18 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
}
think := opts.Think
if think == nil {
falseVal := false
think = &falseVal
}
req := &api.GenerateRequest{
Model: opts.Model,
KeepAlive: opts.KeepAlive,
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
Think: opts.Think,
}
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@ -277,6 +306,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.Format = format
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag.Changed {
think, err := cmd.Flags().GetBool("think")
if err != nil {
return err
}
opts.Think = &think
} else {
opts.Think = nil
}
keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil {
return err
@ -361,6 +401,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
return err
}
warnMissingThinking(cmd.Context(), client, opts.Model)
for _, msg := range info.Messages {
switch msg.Role {
@ -876,6 +917,7 @@ type runOptions struct {
Options map[string]any
MultiModal bool
KeepAlive *api.Duration
Think *bool
}
type displayResponseState struct {
@ -958,6 +1000,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
fn := func(response api.ChatResponse) error {
p.StopAndClear()
@ -965,7 +1009,23 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest = response
role = response.Message.Role
if response.Message.Thinking != "" {
if !thinkTagOpened {
fmt.Print(readline.ColorGrey + readline.ColorBold + "<think>" + readline.ColorDefault + readline.ColorGrey)
thinkTagOpened = true
}
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content
if !thinkTagClosed && thinkTagOpened && content != "" {
fmt.Print(readline.ColorGrey + readline.ColorBold + "</think>" + readline.ColorDefault)
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state)
@ -982,6 +1042,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
if opts.Think != nil {
warnMissingThinking(cmd.Context(), client, opts.Model)
}
if opts.KeepAlive != nil {
@ -1075,6 +1140,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
Think: opts.Think,
}
if err := client.Generate(ctx, &request, fn); err != nil {
@ -1290,6 +1356,7 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().Bool("think", false, "Turn on thinking mode for supported models")
stopCmd := &cobra.Command{
Use: "stop MODEL",

View File

@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "")
}
@ -260,6 +262,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err
}
fmt.Println("Set 'quiet' mode.")
case "think":
think := true
opts.Think = &think
if client, err := api.ClientFromEnvironment(); err == nil {
warnMissingThinking(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'think' mode.")
case "nothink":
think := false
opts.Think = &think
fmt.Println("Set 'nothink' mode.")
case "format":
if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")

30
cmd/templatefmt/main.go Normal file
View File

@ -0,0 +1,30 @@
package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"github.com/ollama/ollama/template"
)
func main() {
flag.Parse()
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "usage: %s <template.gotmpl>\n", os.Args[0])
os.Exit(2)
}
path := flag.Arg(0)
data, err := ioutil.ReadFile(path)
if err != nil {
log.Fatal(err)
}
out, err := template.Format(string(data))
if err != nil {
log.Fatal(err)
}
fmt.Print(out)
}

64
cmd/warn_thinking_test.go Normal file
View File

@ -0,0 +1,64 @@
package cmd
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
warnMissingThinking(context.Background(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}

View File

@ -61,6 +61,8 @@ const (
ColorGrey = Esc + "[38;5;245m"
ColorDefault = Esc + "[0m"
ColorBold = Esc + "[1m"
StartBracketedPaste = Esc + "[?2004h"
EndBracketedPaste = Esc + "[?2004l"
)

View File

@ -37,6 +37,7 @@ var (
errCapabilityInsert = errors.New("insert")
errCapabilityVision = errors.New("vision")
errCapabilityEmbedding = errors.New("embedding")
errCapabilityThinking = errors.New("thinking")
errInsecureProtocol = errors.New("insecure protocol http")
)
@ -106,6 +107,12 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityInsert)
}
// Check for thinking capability
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if openingTag != "" && closingTag != "" {
capabilities = append(capabilities, model.CapabilityThinking)
}
return capabilities
}
@ -122,6 +129,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
model.CapabilityInsert: errCapabilityInsert,
model.CapabilityVision: errCapabilityVision,
model.CapabilityEmbedding: errCapabilityEmbedding,
model.CapabilityThinking: errCapabilityThinking,
}
for _, cap := range want {

View File

@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
isMllama := checkMllamaModelFamily(m)
@ -56,8 +56,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
thinkVal := false
if think != nil {
thinkVal = *think
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}
@ -142,7 +146,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
thinkVal := false
if think != nil {
thinkVal = *think
}
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err
}

View File

@ -318,7 +318,8 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
think := false
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {

View File

@ -18,7 +18,6 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"slices"
"strings"
"syscall"
@ -181,6 +180,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
if req.Think != nil {
// note that the capability is still required even if `Thinking` is false
// because turning off thinking requires the model to support it (e.g.,
// older qwen3 templates don't know how to turn off thinking)
caps = append(caps, model.CapabilityThinking)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
@ -1475,6 +1480,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
if req.Think != nil {
caps = append(caps, model.CapabilityThinking)
}
name := model.ParseName(req.Model)
if !name.IsValid() {
@ -1515,7 +1523,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -1524,6 +1532,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
slog.Debug("chat request", "images", len(images), "prompt", prompt)
var thinkingState *thinkingParser
openingTag, closingTag := inferThinkingTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinkingParser{
openingTag: openingTag,
closingTag: closingTag,
}
}
ch := make(chan any)
go func() {
defer close(ch)
@ -1548,6 +1565,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
},
}
if thinkingState != nil {
if openingTag == "" || closingTag == "" {
// TODO(drifkin): put warning here
} else {
thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
@ -1565,7 +1596,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
sb.WriteString(res.Message.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
@ -1613,9 +1644,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
resp.Message.Content = sb.String()
if req.Think != nil && *req.Think {
resp.Message.Thinking, resp.Message.Content = extractThinking(resp.Message.Content)
}
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
if toolCalls, ok := m.parseToolCalls(resp.Message.Content); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
@ -1643,7 +1677,16 @@ func handleScheduleError(c *gin.Context, name string, err error) {
}
}
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
// returns (thinkingContent, content)
func extractThinking(text string) (string, string) {
thinking := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
thinkingContent, content := thinking.addContent(text)
return thinkingContent, content
}
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
@ -1656,7 +1699,9 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
thinkingContent, content := extractThinking(msg.Content)
msg.Content = content
msg.Thinking = thinkingContent
}
}
}

View File

@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
}
})
t.Run("missing thinking capability", func(t *testing.T) {
think := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Think: &think,
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {

256
server/thinking.go Normal file
View File

@ -0,0 +1,256 @@
package server
import (
"strings"
"text/template"
"text/template/parse"
"unicode"
)
type thinkingParseState int
const (
thinkingParseState_LookingForOpening thinkingParseState = iota
thinkingParseState_Thinking
thinkingParseState_ThinkingDone
)
func (s thinkingParseState) String() string {
switch s {
case thinkingParseState_LookingForOpening:
return "LookingForOpening"
case thinkingParseState_Thinking:
return "Thinking"
case thinkingParseState_ThinkingDone:
return "ThinkingDone"
default:
return "Unknown"
}
}
type thinkingParser struct {
state thinkingParseState
openingTag string
closingTag string
acc strings.Builder
}
// returns the thinking content and the normal content that should be
// immediately sent to the user. It will internally buffer if it needs to see
// more content to disambiguate
func (s *thinkingParser) addContent(content string) (string, string) {
s.acc.WriteString(content)
var thinkingAcc, remainingAcc strings.Builder
var thinking, remaining string
keepLooping := true
// we loop because we might pass through multiple parsing states in a single
// call to addContent, and we want to make sure callers don't have to wait for
// data that's already unambiguous
for keepLooping {
thinking, remaining, keepLooping = eat(s)
thinkingAcc.WriteString(thinking)
remainingAcc.WriteString(remaining)
}
return thinkingAcc.String(), remainingAcc.String()
}
// the additional bool return is true iff we should continue eating
func eat(s *thinkingParser) (string, string, bool) {
switch s.state {
case thinkingParseState_LookingForOpening:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, s.openingTag) {
after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
// after might contain more than just thinking tokens, so we continue
// parsing instead of returning it as thinking tokens here
s.acc.Reset()
s.acc.WriteString(after)
s.state = thinkingParseState_Thinking
return "", "", true
} else if strings.HasPrefix(s.openingTag, trimmed) {
// partial opening seen, so let's keep accumulating
return "", "", false
} else if trimmed == "" {
// saw whitespace only, so let's keep accumulating
return "", "", false
} else {
// didn't see an opening tag, but we have content, so thinking was skipped
s.state = thinkingParseState_ThinkingDone
// note that we use the original content, not the trimmed one because we
// don't want to eat any whitespace in the real content if there were no
// thinking tags
return "", s.acc.String(), false
}
case thinkingParseState_Thinking:
acc := s.acc.String()
if strings.Contains(acc, s.closingTag) {
split := strings.Split(acc, s.closingTag)
thinking := split[0]
remaining := strings.Join(split[1:], s.closingTag)
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
s.acc.Reset()
s.state = thinkingParseState_ThinkingDone
return thinking, remaining, false
} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
thinking := acc[:len(acc)-overlapLen]
remaining := acc[len(acc)-overlapLen:]
s.acc.Reset()
// keep track of the candidate closing tag. We have to buffer it until it
// becomes disambiguated
s.acc.WriteString(remaining)
return thinking, "", false
} else {
// purely just thinking tokens, so we can return them
s.acc.Reset()
return acc, "", false
}
case thinkingParseState_ThinkingDone:
acc := s.acc.String()
s.acc.Reset()
return "", acc, false
default:
panic("unknown state")
}
}
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
if n == nil {
return
}
shouldContinue := enterFn(n)
if !shouldContinue {
return
}
switch x := n.(type) {
case *parse.ListNode:
for _, c := range x.Nodes {
templateVisit(c, enterFn, exitFn)
}
case *parse.BranchNode:
if x.Pipe != nil {
templateVisit(x.Pipe, enterFn, exitFn)
}
if x.List != nil {
templateVisit(x.List, enterFn, exitFn)
}
if x.ElseList != nil {
templateVisit(x.ElseList, enterFn, exitFn)
}
case *parse.ActionNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.WithNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.RangeNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.IfNode:
templateVisit(&x.BranchNode, enterFn, exitFn)
case *parse.TemplateNode:
templateVisit(x.Pipe, enterFn, exitFn)
case *parse.PipeNode:
for _, c := range x.Cmds {
templateVisit(c, enterFn, exitFn)
}
case *parse.CommandNode:
for _, a := range x.Args {
templateVisit(a, enterFn, exitFn)
}
// text, field, number, etc. are leaves nothing to recurse into
}
if exitFn != nil {
exitFn(n)
}
}
// We use a heuristic to infer the tags that surround thinking traces:
// We look for a range node that iterates over "Messages" and then look for a
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
// ListNode and take the first and last TextNodes as the opening and closing
// tags.
func inferThinkingTags(t *template.Template) (string, string) {
ancestors := []parse.Node{}
openingTag := ""
closingTag := ""
enterFn := func(n parse.Node) bool {
ancestors = append(ancestors, n)
switch x := n.(type) {
case *parse.FieldNode:
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
var mostRecentRange *parse.RangeNode
for i := len(ancestors) - 1; i >= 0; i-- {
if r, ok := ancestors[i].(*parse.RangeNode); ok {
mostRecentRange = r
break
}
}
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
return true
}
// TODO(drifkin): to be more robust, check that it's in the action
// part, not the `if`'s pipeline part. We do match on the nearest list
// that starts and ends with text nodes, which makes this not strictly
// necessary for our heuristic
// go up to the nearest ancestor that is a *parse.ListNode
for i := len(ancestors) - 1; i >= 0; i-- {
if l, ok := ancestors[i].(*parse.ListNode); ok {
firstNode := l.Nodes[0]
if t, ok := firstNode.(*parse.TextNode); ok {
openingTag = strings.TrimSpace(t.String())
}
lastNode := l.Nodes[len(l.Nodes)-1]
if t, ok := lastNode.(*parse.TextNode); ok {
closingTag = strings.TrimSpace(t.String())
}
break
}
}
}
}
return true
}
exitFn := func(n parse.Node) {
ancestors = ancestors[:len(ancestors)-1]
}
templateVisit(t.Root, enterFn, exitFn)
return openingTag, closingTag
}
// checks to see if the given field name is present in the pipeline of the given range node
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
found := false
enterFn := func(n parse.Node) bool {
switch x := n.(type) {
case *parse.FieldNode:
if x.Ident[0] == field {
found = true
}
}
return true
}
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
return found
}

286
server/thinking_test.go Normal file
View File

@ -0,0 +1,286 @@
package server
import (
"testing"
"text/template"
)
func TestExtractThinking(t *testing.T) {
tests := []struct {
in, wantContent, wantThink string
}{
{
in: "<think> internal </think> world",
wantThink: "internal ",
wantContent: "world",
},
{
in: "<think>a</think><think>b</think>c",
wantThink: "a",
wantContent: "<think>b</think>c",
},
{
in: "no think",
wantThink: "",
wantContent: "no think",
},
}
for i, tt := range tests {
gotThinking, gotContent := extractThinking(tt.in)
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
}
}
}
func TestThinkingStreaming(t *testing.T) {
type step struct {
input string
wantThinking string
wantContent string
wantStateAfter thinkingParseState
}
cases := []struct {
desc string
skip bool
steps []step
}{
{
desc: "content without a thinking tag",
steps: []step{
{
input: " abc",
wantThinking: "",
wantContent: " abc",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
{
desc: "content before a thinking tag nerfs the thinking tag",
steps: []step{
{
input: " abc <think>def</think> ghi",
wantThinking: "",
wantContent: " abc <think>def</think> ghi",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
{
desc: "building up a thinking tag partially",
// skip: true,
steps: []step{
{
input: " <th",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingParseState_LookingForOpening,
},
{
input: "in",
wantThinking: "",
wantContent: "",
wantStateAfter: thinkingParseState_LookingForOpening,
},
{
input: "k>a",
wantThinking: "a",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
},
},
{
desc: "partial closing tag",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "ink>def",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
{
desc: "partial closing tag fakeout",
steps: []step{
{
input: "<think>abc</th",
wantThinking: "abc",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "ing>def",
wantThinking: "</thing>def",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "ghi</thi",
wantThinking: "ghi",
wantContent: "",
wantStateAfter: thinkingParseState_Thinking,
},
{
input: "nk>jkl",
wantThinking: "",
wantContent: "jkl",
wantStateAfter: thinkingParseState_ThinkingDone,
},
},
},
}
for _, c := range cases {
parser := thinkingParser{
openingTag: "<think>",
closingTag: "</think>",
}
if c.skip {
continue
}
for i, step := range c.steps {
thinking, content := parser.addContent(step.input)
if content != step.wantContent || thinking != step.wantThinking {
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
}
if parser.state != step.wantStateAfter {
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state.String(), step.wantStateAfter.String())
}
}
}
}
func TestInferThinkingTags(t *testing.T) {
cases := []struct {
desc string
tmplString string
wantOpeningTag string
wantClosingTag string
}{
{
desc: "basic",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
{
desc: "doubly nested range",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- range $j, $_ := .NotMessages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ end }}
{{ end }}
`,
wantOpeningTag: "",
wantClosingTag: "",
},
{
desc: "whitespace is trimmed",
tmplString: `
{{ if .Thinking}}
/think
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{ if and $last .Thinking }}
Some text before {{ .Thinking }} Some text after
{{ end }}
{{ end }}
`,
wantOpeningTag: "Some text before",
wantClosingTag: "Some text after",
},
{
desc: "qwen3",
tmplString: `
{{- if or .System .Tools .Thinking }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}
{{- if .Thinking }}
/think
{{- else }}
/no_think
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if and $last .Thinking }}
<think>{{ .Thinking }}</think>
{{ end }}
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
`,
wantOpeningTag: "<think>",
wantClosingTag: "</think>",
},
}
for _, c := range cases {
tmpl := template.Must(template.New("test").Parse(c.tmplString))
openingTag, closingTag := inferThinkingTags(tmpl)
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
}
}
}

101
template/pretty.go Normal file
View File

@ -0,0 +1,101 @@
package template
import (
"strings"
texttmpl "text/template"
"text/template/parse"
)
// Format returns a human-readable representation of the template.
// The formatting indents nested sections such as if/else blocks.
func Format(src string) (string, error) {
tmpl, err := texttmpl.New("pretty").Parse(src)
if err != nil {
return "", err
}
var sb strings.Builder
printNodes(tmpl.Tree.Root, 0, &sb)
return sb.String(), nil
}
func indent(sb *strings.Builder, level int) {
for i := 0; i < level; i++ {
sb.WriteString(" ")
}
}
func printNodes(list *parse.ListNode, level int, sb *strings.Builder) {
if list == nil {
return
}
for _, n := range list.Nodes {
printNode(n, level, sb)
}
}
func printNode(n parse.Node, level int, sb *strings.Builder) {
switch n := n.(type) {
case *parse.TextNode:
text := strings.TrimSpace(string(n.Text))
if text == "" {
return
}
indent(sb, level)
sb.WriteString(text)
sb.WriteByte('\n')
case *parse.ActionNode:
indent(sb, level)
// sb.WriteString("ACTION {{ ")
sb.WriteString(n.String())
// sb.WriteString(" }}\n")
sb.WriteByte('\n')
case *parse.IfNode:
indent(sb, level)
sb.WriteString("{{ if ")
sb.WriteString(n.Pipe.String())
sb.WriteString(" }}\n")
printNodes(n.List, level+1, sb)
if n.ElseList != nil {
indent(sb, level)
sb.WriteString("{{ else }}\n")
printNodes(n.ElseList, level+1, sb)
}
indent(sb, level)
sb.WriteString("{{ end }}\n")
case *parse.RangeNode:
indent(sb, level)
sb.WriteString("{{ range ")
sb.WriteString(n.Pipe.String())
sb.WriteString(" }}\n")
printNodes(n.List, level+1, sb)
if n.ElseList != nil {
indent(sb, level)
sb.WriteString("{{ else }}\n")
printNodes(n.ElseList, level+1, sb)
}
indent(sb, level)
sb.WriteString("{{ end }}\n")
case *parse.WithNode:
indent(sb, level)
sb.WriteString("{{ with ")
sb.WriteString(n.Pipe.String())
sb.WriteString(" }}\n")
printNodes(n.List, level+1, sb)
if n.ElseList != nil {
indent(sb, level)
sb.WriteString("{{ else }}\n")
printNodes(n.ElseList, level+1, sb)
}
indent(sb, level)
sb.WriteString("{{ end }}\n")
case *parse.TemplateNode:
indent(sb, level)
sb.WriteString("{{ template ")
sb.WriteString(n.Name)
sb.WriteString(" }}\n")
default:
indent(sb, level)
sb.WriteString(n.String())
sb.WriteByte('\n')
}
}

30
template/pretty_test.go Normal file
View File

@ -0,0 +1,30 @@
package template
import (
"strings"
"testing"
)
func TestFormatIndentation(t *testing.T) {
tmpl := "{{ if .Cond }}A{{ else }}B{{ end }}"
out, err := Format(tmpl)
if err != nil {
t.Fatal(err)
}
expectedLines := []string{
"{{ if .Cond }}",
" A",
"{{ else }}",
" B",
"{{ end }}",
}
got := strings.Split(strings.TrimSpace(out), "\n")
if len(got) != len(expectedLines) {
t.Fatalf("expected %d lines, got %d: %q", len(expectedLines), len(got), out)
}
for i, line := range expectedLines {
if strings.TrimSpace(got[i]) != strings.TrimSpace(line) {
t.Errorf("line %d = %q, want %q", i, got[i], line)
}
}
}

View File

@ -167,6 +167,10 @@ type Values struct {
api.Tools
Prompt string
Suffix string
Think bool
// whether or not the user explicitly set the thinking flag (vs. it being
// implicitly false). Templates can't see whether `Think` is nil
IsThinkSet bool
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
@ -222,16 +226,20 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
"System": system,
"Messages": messages,
"Tools": v.Tools,
"Response": "",
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
})
}
@ -241,9 +249,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
for _, m := range messages {
execute := func() error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
"System": system,
"Prompt": prompt,
"Response": response,
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
}); err != nil {
return err
}
@ -286,9 +296,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
"System": system,
"Prompt": prompt,
"Response": response,
"Think": v.Think,
"IsThinkSet": v.IsThinkSet,
}); err != nil {
return err
}

View File

@ -8,6 +8,7 @@ const (
CapabilityInsert = Capability("insert")
CapabilityVision = Capability("vision")
CapabilityEmbedding = Capability("embedding")
CapabilityThinking = Capability("thinking")
)
func (c Capability) String() string {