Compare commits

...

11 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| ParthSareen | 0ef4db0e24 | Error handle detokenization model mismatch | 2024-12-18 17:39:34 -08:00 |
| ParthSareen | e3dd90102d | WIP | 2024-12-18 10:49:22 -08:00 |
| ParthSareen | 11acb85ff3 | WIP | 2024-12-17 16:54:47 -08:00 |
| ParthSareen | 1e545ea7a0 | Add caching for model loading | 2024-12-17 15:31:02 -08:00 |
| ParthSareen | e679885733 | WIP updated routes | 2024-12-16 15:49:28 -08:00 |
| ParthSareen | f0a5f7994b | add tests | 2024-12-16 11:06:14 -08:00 |
| ParthSareen | da35ad878b | update docs | 2024-12-15 23:35:43 -08:00 |
| ParthSareen | a5e66a1163 | Better err handling | 2024-12-15 23:14:18 -08:00 |
| ParthSareen | 6fad1637ed | Cleanup | 2024-12-14 20:30:09 -08:00 |
| Yurzs | 24613df094 | docs: add tokenize and detokenize api | 2024-12-14 17:53:24 -08:00 |
| Yurzs | e60db349b7 | api: expose tokenize and detokenize endpoints | 2024-12-14 17:53:24 -08:00 |
9 changed files with 724 additions and 0 deletions

View File

@@ -360,6 +360,24 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
return &resp, nil
}
// Tokenize returns the tokens for a given text.
func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) {
var resp TokenizeResponse
if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Detokenize returns the text for a given list of tokens.
func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) {
var resp DetokenizeResponse
if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
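For reviewers: a minimal round trip through the two new client methods might look like the sketch below. It assumes `api.ClientFromEnvironment` for client construction and a locally available `llama3.2` model; the snippet is illustrative and not part of this change.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Assumes a running Ollama server reachable via OLLAMA_HOST (default localhost:11434).
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// Tokenize text with the new /api/tokenize endpoint.
	tok, err := client.Tokenize(ctx, &api.TokenizeRequest{
		Model: "llama3.2", // assumed to be pulled locally
		Text:  "Why is the sky blue?",
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("tokens:", tok.Tokens)

	// Detokenize the same tokens; with the same model this should round-trip the text.
	det, err := client.Detokenize(ctx, &api.DetokenizeRequest{
		Model:  "llama3.2",
		Tokens: tok.Tokens,
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("text:", det.Text)
}
```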

View File

@@ -293,6 +293,28 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
// TokenizeRequest is the request sent by [Client.Tokenize].
type TokenizeRequest struct {
Model string `json:"model"`
Text string `json:"text"`
}
// TokenizeResponse is the response from [Client.Tokenize].
type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
// DetokenizeRequest is the request sent by [Client.Detokenize].
type DetokenizeRequest struct {
Model string `json:"model"`
Tokens []int `json:"tokens"`
}
// DetokenizeResponse is the response from [Client.Detokenize].
type DetokenizeResponse struct {
Text string `json:"text"`
}
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`

View File

@@ -13,6 +13,8 @@
- [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Tokenize Text](#tokenize-text)
- [Detokenize Tokens](#detokenize-tokens)
## Conventions
@@ -1485,6 +1487,69 @@ A single JSON object will be returned.
}
```
## Tokenize Text
Tokenizes the given text into an array of tokens using the specified model.
```shell
POST /api/tokenize
```
### Parameters
- `model`: name of model to use for tokenization
- `text`: text to tokenize
### Examples
#### Request
```shell
curl -X POST http://localhost:11434/api/tokenize -d '{
"model": "llama3.2",
"text": "Why is the sky blue?"
}'
```
#### Response
```json
{
"tokens": [10445,279,13180,374,6437,30]
}
```
## Detokenize Tokens
Detokenizes an array of tokens back into text using the specified model.
```shell
POST /api/detokenize
```
### Parameters
- `model`: name of model to use for detokenization
- `tokens`: list of tokens to detokenize
### Examples
#### Request
```shell
curl -X POST http://localhost:11434/api/detokenize -d '{
"model": "llama3.2",
"tokens": [10445,374,279,13180,6437,30]
}'
```
#### Response
```json
{"text":"Why is the sky blue?"}
```
## Generate Embedding
> Note: this endpoint has been superseded by `/api/embed`

View File

@@ -449,9 +449,24 @@ type Model struct {
c *C.struct_llama_model
}
func (m *Model) Detokenize(tokens []int) (string, error) {
var text string
for _, token := range tokens {
piece := m.TokenToPiece(token)
if piece == "" {
return "", fmt.Errorf("failed to convert token %d to piece", token)
}
text += piece
}
return text, nil
}
func (m *Model) TokenToPiece(token int) string {
tokenLen := 12
buf := make([]byte, tokenLen)
if token < 0 || token >= m.NumVocab() {
return ""
}
tokenLen = int(C.llama_token_to_piece(
m.c,
C.int32_t(token),

server/model_loader.go (new file, 67 lines)
View File

@@ -0,0 +1,67 @@
package server
import (
"fmt"
"sync"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/types/model"
)
type loadedModel struct {
model llama.Model
modelPath string
}
type modelLoader struct {
// cache stores loaded models keyed by their full path and params hash.
// Only a single model is kept loaded at a time (see evictExistingModel).
cache sync.Map // map[string]*loadedModel
}
func (ml *modelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
modelName := model.ParseName(name)
if !modelName.IsValid() {
return nil, fmt.Errorf("invalid model name: %s", modelName)
}
modelPath, err := GetModel(modelName.String())
if err != nil {
return nil, fmt.Errorf("model not found: %s", modelName)
}
// Create cache key from model path and params hash
cacheKey := fmt.Sprintf("%s-%+v", modelPath.ModelPath, params)
if cached, ok := ml.cache.Load(cacheKey); ok {
return cached.(*loadedModel), nil
}
// Evict existing model if any
ml.evictExistingModel()
llamaModel, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
if err != nil {
return nil, fmt.Errorf("failed to load model: %v", err)
}
loaded := &loadedModel{
model: *llamaModel,
modelPath: modelPath.ModelPath,
}
ml.cache.Store(cacheKey, loaded)
return loaded, nil
}
// evictExistingModel removes any currently loaded model from the cache
// Currently only supports a single model in cache at a time
// TODO: Add proper cache eviction policy (LRU/size/TTL based)
func (ml *modelLoader) evictExistingModel() {
ml.cache.Range(func(key, value any) bool {
if cached, ok := ml.cache.LoadAndDelete(key); ok {
llama.FreeModel(&cached.(*loadedModel).model)
}
return true
})
}
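Review note: a rough sketch, inside the `server` package, of how the handlers below are expected to use this loader (illustrative only, not part of the change). The first call loads a vocab-only model and caches it; a later call with the same name and params returns the cached entry, while a cache miss evicts whatever was loaded before.

```go
// tokenizeWithCachedModel is a hypothetical helper for illustration; the real
// handlers in routes.go call LoadModel directly.
func tokenizeWithCachedModel(ml *modelLoader, name, text string) ([]int, error) {
	// Vocab-only load: enough for tokenize/detokenize without loading weights.
	loaded, err := ml.LoadModel(name, llama.ModelParams{VocabOnly: true})
	if err != nil {
		return nil, err
	}
	// Same flags as TokenizeHandler: no BOS token added, special tokens parsed.
	return loaded.model.Tokenize(text, false, true)
}
```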

View File

@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
@@ -46,6 +47,7 @@ var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
ml modelLoader
}
func init() {
@@ -548,6 +550,105 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req api.TokenizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if errors.Is(err, io.EOF) {
http.Error(w, "missing request body", http.StatusBadRequest)
return
}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Text == "" {
http.Error(w, "missing `text` for tokenization", http.StatusBadRequest)
return
}
if req.Model == "" {
http.Error(w, "missing `model` for tokenization", http.StatusBadRequest)
return
}
loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
VocabOnly: true,
})
if err != nil {
http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
return
}
// Tokenize the text
tokens, err := loadedModel.model.Tokenize(req.Text, false, true)
if err != nil {
http.Error(w, fmt.Sprintf("failed to tokenize text: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.TokenizeResponse{
Tokens: tokens,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}
func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req api.DetokenizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if errors.Is(err, io.EOF) {
http.Error(w, "missing request body", http.StatusBadRequest)
return
}
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Tokens == nil {
http.Error(w, "missing tokens for detokenization", http.StatusBadRequest)
return
}
if req.Model == "" {
http.Error(w, "missing `model` for detokenization", http.StatusBadRequest)
return
}
loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
VocabOnly: true,
})
if err != nil {
http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
return
}
text, err := loadedModel.model.Detokenize(req.Tokens)
if err != nil {
http.Error(w, fmt.Sprintf("failed to detokenize text: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(api.DetokenizeResponse{
Text: text,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}
func (s *Server) PullHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
@@ -1214,6 +1315,8 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.Any("/api/tokenize", gin.WrapF(s.TokenizeHandler))
r.Any("/api/detokenize", gin.WrapF(s.DetokenizeHandler))
r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler)

View File

@@ -46,6 +46,14 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
func (mockRunner) Detokenize(_ context.Context, tokens []int) (string, error) {
var strs []string
for _, t := range tokens {
strs = append(strs, fmt.Sprint(t))
}
return strings.Join(strs, " "), nil
}
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return mock, nil

View File

@@ -0,0 +1,290 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/llm"
)
func TestTokenize(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
t.Run("missing body", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
s.TokenizeHandler(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
s.TokenizeHandler(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("tokenize text", func(t *testing.T) {
// First create the model
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d", w.Code)
}
// Now test tokenization
body, err := json.Marshal(api.TokenizeRequest{
Model: "test",
Text: "Hello world how are you",
})
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
w = httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
r.Header.Set("Content-Type", "application/json")
s.TokenizeHandler(w, r)
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var resp api.TokenizeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Errorf("failed to decode response: %v", err)
}
// Our mock tokenizer creates sequential tokens based on word count
expected := []int{0, 1, 2, 3, 4}
if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("tokenize empty text", func(t *testing.T) {
body, err := json.Marshal(api.TokenizeRequest{
Model: "test",
Text: "",
})
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
r.Header.Set("Content-Type", "application/json")
s.TokenizeHandler(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
func TestDetokenize(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
t.Run("detokenize tokens", func(t *testing.T) {
// Create the model first
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
}
body, err := json.Marshal(api.DetokenizeRequest{
Model: "test",
Tokens: []int{0, 1, 2, 3, 4},
})
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
w = httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
r.Header.Set("Content-Type", "application/json")
s.DetokenizeHandler(w, r)
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
}
var resp api.DetokenizeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Errorf("failed to decode response: %v", err)
}
})
t.Run("detokenize empty tokens", func(t *testing.T) {
body, err := json.Marshal(api.DetokenizeRequest{
Model: "test",
})
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
r.Header.Set("Content-Type", "application/json")
s.DetokenizeHandler(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("detokenize missing model", func(t *testing.T) {
body, err := json.Marshal(api.DetokenizeRequest{
Tokens: []int{0, 1, 2},
})
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
r.Header.Set("Content-Type", "application/json")
s.DetokenizeHandler(w, r)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), "missing `model` for detokenization\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}

View File

@@ -0,0 +1,136 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
)
type mockModelLoader struct {
LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
}
func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
if ml.LoadModelFn != nil {
return ml.LoadModelFn(name, params)
}
return nil, nil
}
type mockModel struct {
llama.Model
TokenizeFn func(text string, addBos bool, addEos bool) ([]int, error)
TokenToPieceFn func(token int) string
}
func (mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
return []int{1, 2, 3}, nil
}
func (mockModel) TokenToPiece(token int) string {
return fmt.Sprint(token)
}
func TestTokenizeHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mockRunner{}),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mockRunner{},
}
},
},
ml: modelLoader{},
}
t.Run("method not allowed", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
}
})
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("missing text", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Model: "test",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Text: "test text",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
}
})
t.Run("model not found", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Model: "nonexistent",
Text: "test text",
})
if w.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
}
})
t.Run("successful tokenization", func(t *testing.T) {
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
Model: "test",
Text: "test text",
})
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp api.TokenizeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
expectedTokens := []int{0, 1}
if len(resp.Tokens) != len(expectedTokens) {
t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
}
for i, token := range resp.Tokens {
if token != expectedTokens[i] {
t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
}
}
})
}