Add /template endpoint
This commit is contained in:
parent
a72f2dce45
commit
1d529d8b7b
@ -360,6 +360,14 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
|
|||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Template(ctx context.Context, req *TemplateRequest) (*TemplateResponse, error) {
|
||||||
|
var resp TemplateResponse
|
||||||
|
if err := c.do(ctx, http.MethodPost, "/api/template", req, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateBlob creates a blob from a file on the server. digest is the
|
// CreateBlob creates a blob from a file on the server. digest is the
|
||||||
// expected SHA256 digest of the file, and r represents the file.
|
// expected SHA256 digest of the file, and r represents the file.
|
||||||
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
|
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
|
||||||
|
10
api/types.go
10
api/types.go
@ -310,6 +310,16 @@ type CreateRequest struct {
|
|||||||
Quantization string `json:"quantization,omitempty"`
|
Quantization string `json:"quantization,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TemplateRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
Tools []Tool `json:"tools"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TemplateResponse struct {
|
||||||
|
TemplatedPrompt string `json:"templated_prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteRequest is the request passed to [Client.Delete].
|
// DeleteRequest is the request passed to [Client.Delete].
|
||||||
type DeleteRequest struct {
|
type DeleteRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
@ -142,6 +142,21 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
return b.String(), images, nil
|
return b.String(), images, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyTemplate(m *Model, msgs []api.Message, tools []api.Tool) (string, error) {
|
||||||
|
isMllama := checkMllamaModelFamily(m)
|
||||||
|
for _, msg := range msgs {
|
||||||
|
if isMllama && len(msg.Images) > 1 {
|
||||||
|
return "", errTooManyImages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools}); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func checkMllamaModelFamily(m *Model) bool {
|
func checkMllamaModelFamily(m *Model) bool {
|
||||||
for _, arch := range m.Config.ModelFamilies {
|
for _, arch := range m.Config.ModelFamilies {
|
||||||
if arch == "mllama" {
|
if arch == "mllama" {
|
||||||
|
@ -1228,6 +1228,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|||||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||||
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||||
r.GET("/api/ps", s.PsHandler)
|
r.GET("/api/ps", s.PsHandler)
|
||||||
|
r.Any("/api/template", gin.WrapF(s.TemplateHandler))
|
||||||
|
|
||||||
// Compatibility endpoints
|
// Compatibility endpoints
|
||||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||||
@ -1451,6 +1452,38 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) TemplateHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req api.TemplateRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetModel(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
switch {
|
||||||
|
case os.IsNotExist(err):
|
||||||
|
http.Error(w, fmt.Sprintf("model '%s' not found", req.Model), http.StatusNotFound)
|
||||||
|
case err.Error() == "invalid model name":
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
default:
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, err := applyTemplate(model, req.Messages, req.Tools)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(api.TemplateResponse{TemplatedPrompt: prompt}); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) ChatHandler(c *gin.Context) {
|
func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user