diff --git a/api/types.go b/api/types.go
index 87844c67c..040476874 100644
--- a/api/types.go
+++ b/api/types.go
@@ -84,6 +84,9 @@ type ChatRequest struct {
 	// Model is the model name, as in [GenerateRequest].
 	Model string `json:"model"`
 
+	// Template overrides the model's default prompt template.
+	Template string `json:"template"`
+
 	// Messages is the messages of the chat - can be used to keep a chat memory.
 	Messages []Message `json:"messages"`
 
diff --git a/cmd/cmd.go b/cmd/cmd.go
index c898c7db6..0231528e6 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -947,6 +947,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 
 	req := &api.ChatRequest{
 		Model:    opts.Model,
+		Template: opts.Template,
 		Messages: opts.Messages,
 		Format:   opts.Format,
 		Options:  opts.Options,
diff --git a/cmd/interactive.go b/cmd/interactive.go
index 9214f2db5..d2ab39179 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -18,6 +18,7 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 )
 
@@ -205,9 +206,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				fmt.Println("Set system message.")
 				sb.Reset()
 			case MultilineTemplate:
-				opts.Template = sb.String()
-				fmt.Println("Set prompt template.")
+				mTemplate := sb.String()
 				sb.Reset()
+				_, err := template.Parse(mTemplate)
+				if err != nil {
+					multiline = MultilineNone
+					scanner.Prompt.UseAlt = false
+					fmt.Println("The template is invalid.")
+					continue
+				}
+				opts.Template = mTemplate
+				fmt.Println("Set prompt template.")
 			}
 
 			multiline = MultilineNone
@@ -369,9 +378,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					fmt.Println("Set system message.")
 					sb.Reset()
 				} else if args[1] == "template" {
-					opts.Template = sb.String()
-					fmt.Println("Set prompt template.")
+					mTemplate := sb.String()
 					sb.Reset()
+					_, err := template.Parse(mTemplate)
+					if err != nil {
+						fmt.Println("The template is invalid.")
+						continue
+					}
+					opts.Template = mTemplate
+					fmt.Println("Set prompt template.")
 				}
 
 				sb.Reset()
diff --git a/server/routes.go b/server/routes.go
index 4059c7c52..52171e97a 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -71,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 
 // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
 // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
 		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
@@ -81,6 +81,13 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 		return nil, nil, nil, err
 	}
 
+	if mTemplate != "" {
+		model.Template, err = template.Parse(mTemplate)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+	}
+
 	if err := model.CheckCapabilities(caps...); err != nil {
 		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
@@ -120,7 +127,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, "", caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
@@ -256,7 +263,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, "", []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -1132,7 +1139,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, req.Template, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
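
Not part of the diff: a minimal sketch of how a caller could exercise the new per-request Template field through the existing Go API client (api.ClientFromEnvironment and Client.Chat). The model name and template string below are placeholders, not values taken from this change.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		// Placeholder model name; any locally pulled model works.
		Model: "llama3",
		// The new field: a Go text/template string that overrides the
		// model's default prompt template for this request only.
		Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
		Messages: []api.Message{
			{Role: "user", Content: "Hello"},
		},
	}

	// Print streamed response chunks as they arrive.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}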