WIP thinking API support
- Allows specifying whether thinking mode should be on or not - Templates get passed a new option so, e.g., qwen3's template can put `/think` or `/no_think` in the system prompt depending on the value of the setting - Add parsing for thinking blocks in both streaming/non-streaming mode - Update the CLI to make use of these changes
This commit is contained in:
parent
a7835c6716
commit
bc8abf7917
15
api/types.go
15
api/types.go
@ -83,6 +83,12 @@ type GenerateRequest struct {
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding. Needs to be a pointer so we can distinguish between false
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@ -108,6 +114,10 @@ type ChatRequest struct {
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@ -128,6 +138,9 @@ func (t Tool) String() string {
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
@ -275,6 +288,8 @@ type Options struct {
|
||||
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
||||
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
|
||||
Think bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
// Runner options which must be set when the model is loaded into memory
|
||||
|
@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedThinking *bool
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
input: `{ "think": true }`,
|
||||
expectedThinking: &trueVal,
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
input: `{ "think": false }`,
|
||||
expectedThinking: &falseVal,
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
input: `{ }`,
|
||||
expectedThinking: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: `{ "think": "true" }`,
|
||||
expectedThinking: nil,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var req GenerateRequest
|
||||
err := json.Unmarshal([]byte(test.input), &req)
|
||||
if test.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedThinking, req.Think)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
67
cmd/cmd.go
67
cmd/cmd.go
@ -38,12 +38,32 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// warnMissingThinking emits a warning if the model does not advertise thinking
|
||||
// support and opts.Thinking is set. Failures to query the capability are
|
||||
// ignored so this does not impact regular usage.
|
||||
func warnMissingThinking(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityThinking {
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
|
||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
||||
|
||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
@ -240,9 +260,18 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
think := opts.Think
|
||||
if think == nil {
|
||||
falseVal := false
|
||||
think = &falseVal
|
||||
}
|
||||
|
||||
req := &api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
|
||||
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
@ -277,6 +306,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
thinkFlag := cmd.Flags().Lookup("think")
|
||||
if thinkFlag.Changed {
|
||||
think, err := cmd.Flags().GetBool("think")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Think = &think
|
||||
} else {
|
||||
opts.Think = nil
|
||||
}
|
||||
|
||||
keepAlive, err := cmd.Flags().GetString("keepalive")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -361,6 +401,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
return err
|
||||
}
|
||||
warnMissingThinking(cmd.Context(), client, opts.Model)
|
||||
|
||||
for _, msg := range info.Messages {
|
||||
switch msg.Role {
|
||||
@ -876,6 +917,7 @@ type runOptions struct {
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@ -958,6 +1000,8 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
p.StopAndClear()
|
||||
@ -965,7 +1009,23 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
latest = response
|
||||
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(readline.ColorGrey + readline.ColorBold + "<think>" + readline.ColorDefault + readline.ColorGrey)
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if !thinkTagClosed && thinkTagOpened && content != "" {
|
||||
fmt.Print(readline.ColorGrey + readline.ColorBold + "</think>" + readline.ColorDefault)
|
||||
thinkTagClosed = true
|
||||
}
|
||||
// purposefully not putting thinking blocks in the response, which would
|
||||
// only be needed if we later added tool calling to the cli (they get
|
||||
// filtered out anyway since current models don't expect them unless you're
|
||||
// about to finish some tool calls)
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
@ -982,6 +1042,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if opts.Think != nil {
|
||||
warnMissingThinking(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
@ -1075,6 +1140,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
System: opts.System,
|
||||
Options: opts.Options,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
@ -1290,6 +1356,7 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
runCmd.Flags().Bool("think", false, "Turn on thinking mode for supported models")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
|
||||
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
@ -260,6 +262,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
think := true
|
||||
opts.Think = &think
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
warnMissingThinking(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'think' mode.")
|
||||
case "nothink":
|
||||
think := false
|
||||
opts.Think = &think
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
|
30
cmd/templatefmt/main.go
Normal file
30
cmd/templatefmt/main.go
Normal file
@ -0,0 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
if flag.NArg() != 1 {
|
||||
fmt.Fprintf(os.Stderr, "usage: %s <template.gotmpl>\n", os.Args[0])
|
||||
os.Exit(2)
|
||||
}
|
||||
path := flag.Arg(0)
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
out, err := template.Format(string(data))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Print(out)
|
||||
}
|
64
cmd/warn_thinking_test.go
Normal file
64
cmd/warn_thinking_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Test that a warning is printed when thinking is requested but not supported.
|
||||
func TestWarnMissingThinking(t *testing.T) {
|
||||
cases := []struct {
|
||||
capabilities []model.Capability
|
||||
expectWarn bool
|
||||
}{
|
||||
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
|
||||
{capabilities: []model.Capability{}, expectWarn: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
|
||||
}
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
resp := api.ShowResponse{Capabilities: tc.capabilities}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
warnMissingThinking(context.Background(), client, "m")
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
out, _ := io.ReadAll(r)
|
||||
|
||||
warned := strings.Contains(string(out), "warning:")
|
||||
if tc.expectWarn && !warned {
|
||||
t.Errorf("expected warning, got none")
|
||||
}
|
||||
if !tc.expectWarn && warned {
|
||||
t.Errorf("did not expect warning, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
}
|
@ -61,6 +61,8 @@ const (
|
||||
ColorGrey = Esc + "[38;5;245m"
|
||||
ColorDefault = Esc + "[0m"
|
||||
|
||||
ColorBold = Esc + "[1m"
|
||||
|
||||
StartBracketedPaste = Esc + "[?2004h"
|
||||
EndBracketedPaste = Esc + "[?2004l"
|
||||
)
|
||||
|
@ -37,6 +37,7 @@ var (
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
@ -106,6 +107,12 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityInsert)
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
||||
if openingTag != "" && closingTag != "" {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
@ -122,6 +129,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
}
|
||||
|
||||
for _, cap := range want {
|
||||
|
@ -22,7 +22,7 @@ var errTooManyImages = errors.New("vision model only supports a single image per
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
|
||||
var system []api.Message
|
||||
|
||||
isMllama := checkMllamaModelFamily(m)
|
||||
@ -56,8 +56,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
}
|
||||
|
||||
thinkVal := false
|
||||
if think != nil {
|
||||
thinkVal = *think
|
||||
}
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
@ -142,7 +146,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
|
||||
thinkVal := false
|
||||
if think != nil {
|
||||
thinkVal = *think
|
||||
}
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
@ -318,7 +318,8 @@ func TestChatPrompt(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
@ -181,6 +180,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, model.CapabilityInsert)
|
||||
}
|
||||
if req.Think != nil {
|
||||
// note that the capability is still required even if `Thinking` is false
|
||||
// because turning off thinking requires the model to support it (e.g.,
|
||||
// older qwen3 templates don't know how to turn off thinking)
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
@ -1475,6 +1480,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
}
|
||||
if req.Think != nil {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
@ -1515,7 +1523,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@ -1524,6 +1532,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||
|
||||
var thinkingState *thinkingParser
|
||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||
thinkingState = &thinkingParser{
|
||||
openingTag: openingTag,
|
||||
closingTag: closingTag,
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
@ -1548,6 +1565,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
if openingTag == "" || closingTag == "" {
|
||||
// TODO(drifkin): put warning here
|
||||
} else {
|
||||
thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
res.Message.Thinking = thinkingContent
|
||||
}
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
@ -1565,7 +1596,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
// Streaming tool calls:
|
||||
// If tools are recognized, use a flag to track the sending of a tool downstream
|
||||
// This ensures that content is cleared from the message on the last chunk sent
|
||||
sb.WriteString(r.Content)
|
||||
sb.WriteString(res.Message.Content)
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
@ -1613,9 +1644,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
resp.Message.Content = sb.String()
|
||||
if req.Think != nil && *req.Think {
|
||||
resp.Message.Thinking, resp.Message.Content = extractThinking(resp.Message.Content)
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
if toolCalls, ok := m.parseToolCalls(resp.Message.Content); ok {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
@ -1643,7 +1677,16 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
|
||||
// returns (thinkingContent, content)
|
||||
func extractThinking(text string) (string, string) {
|
||||
thinking := thinkingParser{
|
||||
openingTag: "<think>",
|
||||
closingTag: "</think>",
|
||||
}
|
||||
|
||||
thinkingContent, content := thinking.addContent(text)
|
||||
return thinkingContent, content
|
||||
}
|
||||
|
||||
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
|
||||
@ -1656,7 +1699,9 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
|
||||
for i, msg := range msgs {
|
||||
if msg.Role == "assistant" && i < finalUserIndex {
|
||||
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
|
||||
thinkingContent, content := extractThinking(msg.Content)
|
||||
msg.Content = content
|
||||
msg.Thinking = thinkingContent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing thinking capability", func(t *testing.T) {
|
||||
think := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Think: &think,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
|
256
server/thinking.go
Normal file
256
server/thinking.go
Normal file
@ -0,0 +1,256 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type thinkingParseState int
|
||||
|
||||
const (
|
||||
thinkingParseState_LookingForOpening thinkingParseState = iota
|
||||
thinkingParseState_Thinking
|
||||
thinkingParseState_ThinkingDone
|
||||
)
|
||||
|
||||
func (s thinkingParseState) String() string {
|
||||
switch s {
|
||||
case thinkingParseState_LookingForOpening:
|
||||
return "LookingForOpening"
|
||||
case thinkingParseState_Thinking:
|
||||
return "Thinking"
|
||||
case thinkingParseState_ThinkingDone:
|
||||
return "ThinkingDone"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type thinkingParser struct {
|
||||
state thinkingParseState
|
||||
openingTag string
|
||||
closingTag string
|
||||
acc strings.Builder
|
||||
}
|
||||
|
||||
// returns the thinking content and the normal content that should be
|
||||
// immediately sent to the user. It will internally buffer if it needs to see
|
||||
// more content to disambiguate
|
||||
func (s *thinkingParser) addContent(content string) (string, string) {
|
||||
s.acc.WriteString(content)
|
||||
|
||||
var thinkingAcc, remainingAcc strings.Builder
|
||||
|
||||
var thinking, remaining string
|
||||
keepLooping := true
|
||||
// we loop because we might pass through multiple parsing states in a single
|
||||
// call to addContent, and we want to make sure callers don't have to wait for
|
||||
// data that's already unambiguous
|
||||
for keepLooping {
|
||||
thinking, remaining, keepLooping = eat(s)
|
||||
thinkingAcc.WriteString(thinking)
|
||||
remainingAcc.WriteString(remaining)
|
||||
}
|
||||
|
||||
return thinkingAcc.String(), remainingAcc.String()
|
||||
}
|
||||
|
||||
// the additional bool return is true iff we should continue eating
|
||||
func eat(s *thinkingParser) (string, string, bool) {
|
||||
switch s.state {
|
||||
case thinkingParseState_LookingForOpening:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, s.openingTag) {
|
||||
after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
// after might contain more than just thinking tokens, so we continue
|
||||
// parsing instead of returning it as thinking tokens here
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = thinkingParseState_Thinking
|
||||
return "", "", true
|
||||
} else if strings.HasPrefix(s.openingTag, trimmed) {
|
||||
// partial opening seen, so let's keep accumulating
|
||||
return "", "", false
|
||||
} else if trimmed == "" {
|
||||
// saw whitespace only, so let's keep accumulating
|
||||
return "", "", false
|
||||
} else {
|
||||
// didn't see an opening tag, but we have content, so thinking was skipped
|
||||
s.state = thinkingParseState_ThinkingDone
|
||||
// note that we use the original content, not the trimmed one because we
|
||||
// don't want to eat any whitespace in the real content if there were no
|
||||
// thinking tags
|
||||
return "", s.acc.String(), false
|
||||
}
|
||||
case thinkingParseState_Thinking:
|
||||
acc := s.acc.String()
|
||||
if strings.Contains(acc, s.closingTag) {
|
||||
split := strings.Split(acc, s.closingTag)
|
||||
thinking := split[0]
|
||||
remaining := strings.Join(split[1:], s.closingTag)
|
||||
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
s.state = thinkingParseState_ThinkingDone
|
||||
return thinking, remaining, false
|
||||
} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
|
||||
thinking := acc[:len(acc)-overlapLen]
|
||||
remaining := acc[len(acc)-overlapLen:]
|
||||
s.acc.Reset()
|
||||
// keep track of the candidate closing tag. We have to buffer it until it
|
||||
// becomes disambiguated
|
||||
s.acc.WriteString(remaining)
|
||||
return thinking, "", false
|
||||
} else {
|
||||
// purely just thinking tokens, so we can return them
|
||||
s.acc.Reset()
|
||||
return acc, "", false
|
||||
}
|
||||
case thinkingParseState_ThinkingDone:
|
||||
acc := s.acc.String()
|
||||
s.acc.Reset()
|
||||
return "", acc, false
|
||||
default:
|
||||
panic("unknown state")
|
||||
}
|
||||
}
|
||||
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
shouldContinue := enterFn(n)
|
||||
if !shouldContinue {
|
||||
return
|
||||
}
|
||||
switch x := n.(type) {
|
||||
case *parse.ListNode:
|
||||
for _, c := range x.Nodes {
|
||||
templateVisit(c, enterFn, exitFn)
|
||||
}
|
||||
case *parse.BranchNode:
|
||||
if x.Pipe != nil {
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
}
|
||||
if x.List != nil {
|
||||
templateVisit(x.List, enterFn, exitFn)
|
||||
}
|
||||
if x.ElseList != nil {
|
||||
templateVisit(x.ElseList, enterFn, exitFn)
|
||||
}
|
||||
case *parse.ActionNode:
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
case *parse.WithNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.RangeNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.IfNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.TemplateNode:
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
case *parse.PipeNode:
|
||||
for _, c := range x.Cmds {
|
||||
templateVisit(c, enterFn, exitFn)
|
||||
}
|
||||
case *parse.CommandNode:
|
||||
for _, a := range x.Args {
|
||||
templateVisit(a, enterFn, exitFn)
|
||||
}
|
||||
// text, field, number, etc. are leaves – nothing to recurse into
|
||||
}
|
||||
if exitFn != nil {
|
||||
exitFn(n)
|
||||
}
|
||||
}
|
||||
|
||||
// We use a heuristic to infer the tags that surround thinking traces:
|
||||
// We look for a range node that iterates over "Messages" and then look for a
|
||||
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
|
||||
// ListNode and take the first and last TextNodes as the opening and closing
|
||||
// tags.
|
||||
func inferThinkingTags(t *template.Template) (string, string) {
|
||||
ancestors := []parse.Node{}
|
||||
|
||||
openingTag := ""
|
||||
closingTag := ""
|
||||
|
||||
enterFn := func(n parse.Node) bool {
|
||||
ancestors = append(ancestors, n)
|
||||
|
||||
switch x := n.(type) {
|
||||
case *parse.FieldNode:
|
||||
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
|
||||
var mostRecentRange *parse.RangeNode
|
||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||
if r, ok := ancestors[i].(*parse.RangeNode); ok {
|
||||
mostRecentRange = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(drifkin): to be more robust, check that it's in the action
|
||||
// part, not the `if`'s pipeline part. We do match on the nearest list
|
||||
// that starts and ends with text nodes, which makes this not strictly
|
||||
// necessary for our heuristic
|
||||
|
||||
// go up to the nearest ancestor that is a *parse.ListNode
|
||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||
if l, ok := ancestors[i].(*parse.ListNode); ok {
|
||||
firstNode := l.Nodes[0]
|
||||
if t, ok := firstNode.(*parse.TextNode); ok {
|
||||
openingTag = strings.TrimSpace(t.String())
|
||||
}
|
||||
lastNode := l.Nodes[len(l.Nodes)-1]
|
||||
if t, ok := lastNode.(*parse.TextNode); ok {
|
||||
closingTag = strings.TrimSpace(t.String())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
exitFn := func(n parse.Node) {
|
||||
ancestors = ancestors[:len(ancestors)-1]
|
||||
}
|
||||
|
||||
templateVisit(t.Root, enterFn, exitFn)
|
||||
|
||||
return openingTag, closingTag
|
||||
}
|
||||
|
||||
// checks to see if the given field name is present in the pipeline of the given range node
|
||||
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
|
||||
found := false
|
||||
enterFn := func(n parse.Node) bool {
|
||||
switch x := n.(type) {
|
||||
case *parse.FieldNode:
|
||||
if x.Ident[0] == field {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
|
||||
return found
|
||||
}
|
286
server/thinking_test.go
Normal file
286
server/thinking_test.go
Normal file
@ -0,0 +1,286 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestExtractThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantContent, wantThink string
|
||||
}{
|
||||
{
|
||||
in: "<think> internal </think> world",
|
||||
wantThink: "internal ",
|
||||
wantContent: "world",
|
||||
},
|
||||
{
|
||||
in: "<think>a</think><think>b</think>c",
|
||||
wantThink: "a",
|
||||
wantContent: "<think>b</think>c",
|
||||
},
|
||||
{
|
||||
in: "no think",
|
||||
wantThink: "",
|
||||
wantContent: "no think",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
gotThinking, gotContent := extractThinking(tt.in)
|
||||
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
|
||||
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingStreaming(t *testing.T) {
|
||||
|
||||
type step struct {
|
||||
input string
|
||||
wantThinking string
|
||||
wantContent string
|
||||
wantStateAfter thinkingParseState
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
skip bool
|
||||
steps []step
|
||||
}{
|
||||
{
|
||||
desc: "content without a thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " abc",
|
||||
wantThinking: "",
|
||||
wantContent: " abc",
|
||||
wantStateAfter: thinkingParseState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before a thinking tag nerfs the thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " abc <think>def</think> ghi",
|
||||
wantThinking: "",
|
||||
wantContent: " abc <think>def</think> ghi",
|
||||
wantStateAfter: thinkingParseState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "building up a thinking tag partially",
|
||||
// skip: true,
|
||||
steps: []step{
|
||||
{
|
||||
input: " <th",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_LookingForOpening,
|
||||
},
|
||||
{
|
||||
input: "in",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_LookingForOpening,
|
||||
},
|
||||
{
|
||||
input: "k>a",
|
||||
wantThinking: "a",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_Thinking,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial closing tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>abc</th",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ink>def",
|
||||
wantThinking: "",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingParseState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial closing tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>abc</th",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ing>def",
|
||||
wantThinking: "</thing>def",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ghi</thi",
|
||||
wantThinking: "ghi",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingParseState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "nk>jkl",
|
||||
wantThinking: "",
|
||||
wantContent: "jkl",
|
||||
wantStateAfter: thinkingParseState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
parser := thinkingParser{
|
||||
openingTag: "<think>",
|
||||
closingTag: "</think>",
|
||||
}
|
||||
if c.skip {
|
||||
continue
|
||||
}
|
||||
for i, step := range c.steps {
|
||||
thinking, content := parser.addContent(step.input)
|
||||
if content != step.wantContent || thinking != step.wantThinking {
|
||||
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
|
||||
}
|
||||
if parser.state != step.wantStateAfter {
|
||||
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state.String(), step.wantStateAfter.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferThinkingTags(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
tmplString string
|
||||
wantOpeningTag string
|
||||
wantClosingTag string
|
||||
}{
|
||||
{
|
||||
desc: "basic",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "<think>",
|
||||
wantClosingTag: "</think>",
|
||||
},
|
||||
{
|
||||
desc: "doubly nested range",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- range $j, $_ := .NotMessages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "",
|
||||
wantClosingTag: "",
|
||||
},
|
||||
{
|
||||
desc: "whitespace is trimmed",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
Some text before {{ .Thinking }} Some text after
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "Some text before",
|
||||
wantClosingTag: "Some text after",
|
||||
},
|
||||
{
|
||||
desc: "qwen3",
|
||||
tmplString: `
|
||||
{{- if or .System .Tools .Thinking }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
{{- if .Thinking }}
|
||||
/think
|
||||
{{- else }}
|
||||
/no_think
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
`,
|
||||
wantOpeningTag: "<think>",
|
||||
wantClosingTag: "</think>",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
tmpl := template.Must(template.New("test").Parse(c.tmplString))
|
||||
openingTag, closingTag := inferThinkingTags(tmpl)
|
||||
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
|
||||
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
|
||||
}
|
||||
}
|
||||
}
|
101
template/pretty.go
Normal file
101
template/pretty.go
Normal file
@ -0,0 +1,101 @@
|
||||
package template
|
||||
|
||||
import (
|
||||
"strings"
|
||||
texttmpl "text/template"
|
||||
"text/template/parse"
|
||||
)
|
||||
|
||||
// Format returns a human-readable representation of the template.
|
||||
// The formatting indents nested sections such as if/else blocks.
|
||||
func Format(src string) (string, error) {
|
||||
tmpl, err := texttmpl.New("pretty").Parse(src)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var sb strings.Builder
|
||||
printNodes(tmpl.Tree.Root, 0, &sb)
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func indent(sb *strings.Builder, level int) {
|
||||
for i := 0; i < level; i++ {
|
||||
sb.WriteString(" ")
|
||||
}
|
||||
}
|
||||
|
||||
func printNodes(list *parse.ListNode, level int, sb *strings.Builder) {
|
||||
if list == nil {
|
||||
return
|
||||
}
|
||||
for _, n := range list.Nodes {
|
||||
printNode(n, level, sb)
|
||||
}
|
||||
}
|
||||
|
||||
func printNode(n parse.Node, level int, sb *strings.Builder) {
|
||||
switch n := n.(type) {
|
||||
case *parse.TextNode:
|
||||
text := strings.TrimSpace(string(n.Text))
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
indent(sb, level)
|
||||
sb.WriteString(text)
|
||||
sb.WriteByte('\n')
|
||||
case *parse.ActionNode:
|
||||
indent(sb, level)
|
||||
// sb.WriteString("ACTION {{ ")
|
||||
sb.WriteString(n.String())
|
||||
// sb.WriteString(" }}\n")
|
||||
sb.WriteByte('\n')
|
||||
case *parse.IfNode:
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ if ")
|
||||
sb.WriteString(n.Pipe.String())
|
||||
sb.WriteString(" }}\n")
|
||||
printNodes(n.List, level+1, sb)
|
||||
if n.ElseList != nil {
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ else }}\n")
|
||||
printNodes(n.ElseList, level+1, sb)
|
||||
}
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ end }}\n")
|
||||
case *parse.RangeNode:
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ range ")
|
||||
sb.WriteString(n.Pipe.String())
|
||||
sb.WriteString(" }}\n")
|
||||
printNodes(n.List, level+1, sb)
|
||||
if n.ElseList != nil {
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ else }}\n")
|
||||
printNodes(n.ElseList, level+1, sb)
|
||||
}
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ end }}\n")
|
||||
case *parse.WithNode:
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ with ")
|
||||
sb.WriteString(n.Pipe.String())
|
||||
sb.WriteString(" }}\n")
|
||||
printNodes(n.List, level+1, sb)
|
||||
if n.ElseList != nil {
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ else }}\n")
|
||||
printNodes(n.ElseList, level+1, sb)
|
||||
}
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ end }}\n")
|
||||
case *parse.TemplateNode:
|
||||
indent(sb, level)
|
||||
sb.WriteString("{{ template ")
|
||||
sb.WriteString(n.Name)
|
||||
sb.WriteString(" }}\n")
|
||||
default:
|
||||
indent(sb, level)
|
||||
sb.WriteString(n.String())
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
30
template/pretty_test.go
Normal file
30
template/pretty_test.go
Normal file
@ -0,0 +1,30 @@
|
||||
package template
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFormatIndentation(t *testing.T) {
|
||||
tmpl := "{{ if .Cond }}A{{ else }}B{{ end }}"
|
||||
out, err := Format(tmpl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectedLines := []string{
|
||||
"{{ if .Cond }}",
|
||||
" A",
|
||||
"{{ else }}",
|
||||
" B",
|
||||
"{{ end }}",
|
||||
}
|
||||
got := strings.Split(strings.TrimSpace(out), "\n")
|
||||
if len(got) != len(expectedLines) {
|
||||
t.Fatalf("expected %d lines, got %d: %q", len(expectedLines), len(got), out)
|
||||
}
|
||||
for i, line := range expectedLines {
|
||||
if strings.TrimSpace(got[i]) != strings.TrimSpace(line) {
|
||||
t.Errorf("line %d = %q, want %q", i, got[i], line)
|
||||
}
|
||||
}
|
||||
}
|
@ -167,6 +167,10 @@ type Values struct {
|
||||
api.Tools
|
||||
Prompt string
|
||||
Suffix string
|
||||
Think bool
|
||||
// whether or not the user explicitly set the thinking flag (vs. it being
|
||||
// implicitly false). Templates can't see whether `Think` is nil
|
||||
IsThinkSet bool
|
||||
|
||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||
forceLegacy bool
|
||||
@ -225,6 +229,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
"Prompt": v.Prompt,
|
||||
"Suffix": v.Suffix,
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
})
|
||||
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
@ -232,6 +238,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
})
|
||||
}
|
||||
|
||||
@ -244,6 +252,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -289,6 +299,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ const (
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
Loading…
x
Reference in New Issue
Block a user