This commit is contained in:
ParthSareen 2024-12-14 20:30:09 -08:00
parent 24613df094
commit 6fad1637ed
3 changed files with 50 additions and 40 deletions

View File

@ -360,7 +360,7 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
return &resp, nil return &resp, nil
} }
// Tokenize tokenizes a string. // Tokenize returns the tokens for a given text.
func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) { func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) {
var resp TokenizeResponse var resp TokenizeResponse
if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil {
@ -369,7 +369,7 @@ func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeR
return &resp, nil return &resp, nil
} }
// Detokenize detokenizes a string. // Detokenize returns the text for a given list of tokens.
func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) { func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) {
var resp DetokenizeResponse var resp DetokenizeResponse
if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil {

View File

@ -293,26 +293,22 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
// TokenizeRequest is the request passed to [Client.Tokenize]. // TokenizeRequest is the request sent by [Client.Tokenize].
type TokenizeRequest struct { type TokenizeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Text string `json:"text"`
// KeepAlive controls how long the model will stay loaded in memory following // KeepAlive controls how long the model will stay loaded in memory following
// this request. // this request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
} }
// TokenizeResponse is the response from [Client.Tokenize]. // TokenizeResponse is the response from [Client.Tokenize].
type TokenizeResponse struct { type TokenizeResponse struct {
Model string `json:"model"`
Tokens []int `json:"tokens"` Tokens []int `json:"tokens"`
} }
// DetokenizeRequest is the request passed to [Client.Detokenize]. // DetokenizeRequest is the request sent by [Client.Detokenize].
type DetokenizeRequest struct { type DetokenizeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Tokens []int `json:"tokens"` Tokens []int `json:"tokens"`
@ -320,14 +316,10 @@ type DetokenizeRequest struct {
// KeepAlive controls how long the model will stay loaded in memory following // KeepAlive controls how long the model will stay loaded in memory following
// this request. // this request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
} }
// DetokenizeResponse is the response from [Client.Detokenize]. // DetokenizeResponse is the response from [Client.Detokenize].
type DetokenizeResponse struct { type DetokenizeResponse struct {
Model string `json:"model"`
Text string `json:"text"` Text string `json:"text"`
} }

View File

@ -548,54 +548,72 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func (s *Server) TokenizeHandler(c *gin.Context) { func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req api.TokenizeRequest var req api.TokenizeRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) if errors.Is(err, io.EOF) {
http.Error(w, "missing request body", http.StatusBadRequest)
return return
} else if err != nil { }
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) runner, _, _, err := s.scheduleRunner(r.Context(), req.Model, []Capability{}, nil, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) http.Error(w, fmt.Sprintf("model '%s' not found", req.Model), http.StatusNotFound)
return return
} }
tokens, err := r.Tokenize(c.Request.Context(), req.Prompt) tokens, err := runner.Tokenize(r.Context(), req.Text)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
c.JSON(http.StatusOK, api.TokenizeResponse{Model: req.Model, Tokens: tokens}) w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(api.TokenizeResponse{
Tokens: tokens,
})
}
func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
} }
func (s *Server) DetokenizeHandler(c *gin.Context) {
var req api.DetokenizeRequest var req api.DetokenizeRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) if errors.Is(err, io.EOF) {
http.Error(w, "missing request body", http.StatusBadRequest)
return return
} else if err != nil { }
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) runner, _, _, err := s.scheduleRunner(r.Context(), req.Model, []Capability{}, nil, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) http.Error(w, fmt.Sprintf("model '%s' not found", req.Model), http.StatusNotFound)
return return
} }
text, err := r.Detokenize(c.Request.Context(), req.Tokens) text, err := runner.Detokenize(r.Context(), req.Tokens)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
c.JSON(http.StatusOK, api.DetokenizeResponse{Model: req.Model, Text: text}) w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(api.DetokenizeResponse{
Text: text,
})
} }
func (s *Server) PullHandler(c *gin.Context) { func (s *Server) PullHandler(c *gin.Context) {
@ -1264,8 +1282,8 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/chat", s.ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/tokenize", s.TokenizeHandler) r.Any("/api/tokenize", gin.WrapF(s.TokenizeHandler))
r.POST("/api/detokenize", s.DetokenizeHandler) r.Any("/api/detokenize", gin.WrapF(s.DetokenizeHandler))
r.POST("/api/create", s.CreateHandler) r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler) r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler) r.POST("/api/copy", s.CopyHandler)