Compare commits
11 Commits
main
...
parth/toke
Author | SHA1 | Date | |
---|---|---|---|
![]() |
0ef4db0e24 | ||
![]() |
e3dd90102d | ||
![]() |
11acb85ff3 | ||
![]() |
1e545ea7a0 | ||
![]() |
e679885733 | ||
![]() |
f0a5f7994b | ||
![]() |
da35ad878b | ||
![]() |
a5e66a1163 | ||
![]() |
6fad1637ed | ||
![]() |
24613df094 | ||
![]() |
e60db349b7 |
@ -360,6 +360,24 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Tokenize returns the tokens for a given text.
|
||||
func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) {
|
||||
var resp TokenizeResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Detokenize returns the text for a given list of tokens.
|
||||
func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) {
|
||||
var resp DetokenizeResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CreateBlob creates a blob from a file on the server. digest is the
|
||||
// expected SHA256 digest of the file, and r represents the file.
|
||||
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
|
||||
|
22
api/types.go
22
api/types.go
@ -293,6 +293,28 @@ type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
// TokenizeRequest is the request sent by [Client.Tokenize].
|
||||
type TokenizeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// TokenizeResponse is the response from [Client.Tokenize].
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
// DetokenizeRequest is the request sent by [Client.Detokenize].
|
||||
type DetokenizeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
// DetokenizeResponse is the response from [Client.Detokenize].
|
||||
type DetokenizeResponse struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
65
docs/api.md
65
docs/api.md
@ -13,6 +13,8 @@
|
||||
- [Push a Model](#push-a-model)
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
- [Tokenize Text](#tokenize-text)
|
||||
- [Detokenize Tokens](#detokenize-tokens)
|
||||
|
||||
## Conventions
|
||||
|
||||
@ -1485,6 +1487,69 @@ A single JSON object will be returned.
|
||||
}
|
||||
```
|
||||
|
||||
## Tokenize Text
|
||||
|
||||
Tokenize text to an array of tokens using a specific model.
|
||||
|
||||
```shell
|
||||
POST /api/tokenize
|
||||
```
|
||||
|
||||
##### Parameters
|
||||
|
||||
- `model`: name of model to use for tokenization
|
||||
- `text`: text to tokenize
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/tokenize -d '{
|
||||
"model": "llama3.2",
|
||||
"text": "Why is the sky blue?"
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"tokens": [10445,279,13180,374,6437,30]
|
||||
}
|
||||
```
|
||||
|
||||
## Detokenize Tokens
|
||||
|
||||
Detokenize tokens to text using a specific model.
|
||||
|
||||
```shell
|
||||
POST /api/detokenize
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
|
||||
- `model`: name of model to use for detokenization
|
||||
- `tokens`: list of tokens to detokenize
|
||||
|
||||
### Examples
|
||||
|
||||
#### Request
|
||||
|
||||
```shell
|
||||
curl -X POST http://localhost:11434/api/detokenize -d '{
|
||||
"model": "llama3.2",
|
||||
"tokens": [10445,374,279,13180,6437,30]
|
||||
}'
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
```json
|
||||
{"text":"Why is the sky blue?"}
|
||||
```
|
||||
|
||||
|
||||
## Generate Embedding
|
||||
|
||||
> Note: this endpoint has been superseded by `/api/embed`
|
||||
|
@ -449,9 +449,24 @@ type Model struct {
|
||||
c *C.struct_llama_model
|
||||
}
|
||||
|
||||
func (m *Model) Detokenize(tokens []int) (string, error) {
|
||||
var text string
|
||||
for _, token := range tokens {
|
||||
piece := m.TokenToPiece(token)
|
||||
if piece == "" {
|
||||
return "", fmt.Errorf("failed to convert token %d to piece", token)
|
||||
}
|
||||
text += piece
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func (m *Model) TokenToPiece(token int) string {
|
||||
tokenLen := 12
|
||||
buf := make([]byte, tokenLen)
|
||||
if token > m.NumVocab() {
|
||||
return ""
|
||||
}
|
||||
tokenLen = int(C.llama_token_to_piece(
|
||||
m.c,
|
||||
C.int32_t(token),
|
||||
|
67
server/model_loader.go
Normal file
67
server/model_loader.go
Normal file
@ -0,0 +1,67 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type loadedModel struct {
|
||||
model llama.Model
|
||||
modelPath string
|
||||
}
|
||||
|
||||
type modelLoader struct {
|
||||
cache sync.Map
|
||||
}
|
||||
|
||||
// modelCache stores loaded models keyed by their full path and params hash
|
||||
var modelCache sync.Map // map[string]*loadedModel
|
||||
|
||||
func (ml *modelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
|
||||
modelName := model.ParseName(name)
|
||||
if !modelName.IsValid() {
|
||||
return nil, fmt.Errorf("invalid model name: %s", modelName)
|
||||
}
|
||||
|
||||
modelPath, err := GetModel(modelName.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %s", modelName)
|
||||
}
|
||||
|
||||
// Create cache key from model path and params hash
|
||||
cacheKey := fmt.Sprintf("%s-%+v", modelPath.ModelPath, params)
|
||||
if cached, ok := modelCache.Load(cacheKey); ok {
|
||||
return cached.(*loadedModel), nil
|
||||
}
|
||||
|
||||
// Evict existing model if any
|
||||
ml.evictExistingModel()
|
||||
|
||||
model, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load model: %v", err)
|
||||
}
|
||||
|
||||
loaded := &loadedModel{
|
||||
model: *model,
|
||||
modelPath: modelPath.ModelPath,
|
||||
}
|
||||
modelCache.Store(cacheKey, loaded)
|
||||
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// evictExistingModel removes any currently loaded model from the cache
|
||||
// Currently only supports a single model in cache at a time
|
||||
// TODO: Add proper cache eviction policy (LRU/size/TTL based)
|
||||
func (ml *modelLoader) evictExistingModel() {
|
||||
ml.cache.Range(func(key, value any) bool {
|
||||
if cached, ok := ml.cache.LoadAndDelete(key); ok {
|
||||
llama.FreeModel(&cached.(*loadedModel).model)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
103
server/routes.go
103
server/routes.go
@ -30,6 +30,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@ -46,6 +47,7 @@ var mode string = gin.DebugMode
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
ml modelLoader
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -548,6 +550,105 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.TokenizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
http.Error(w, "missing request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Text == "" {
|
||||
http.Error(w, "missing `text` for tokenization", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
http.Error(w, "missing `model` for tokenization", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
|
||||
VocabOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Tokenize the text
|
||||
tokens, err := loadedModel.model.Tokenize(req.Text, false, true)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to tokenize text: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.TokenizeResponse{
|
||||
Tokens: tokens,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.DetokenizeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
http.Error(w, "missing request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Tokens == nil {
|
||||
http.Error(w, "missing tokens for detokenization", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
http.Error(w, "missing `model` for detokenization", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
|
||||
VocabOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
text, err := loadedModel.model.Detokenize(req.Tokens)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to detokenize text: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(api.DetokenizeResponse{
|
||||
Text: text,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) PullHandler(c *gin.Context) {
|
||||
var req api.PullRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
@ -1214,6 +1315,8 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.Any("/api/tokenize", gin.WrapF(s.TokenizeHandler))
|
||||
r.Any("/api/detokenize", gin.WrapF(s.DetokenizeHandler))
|
||||
r.POST("/api/create", s.CreateHandler)
|
||||
r.POST("/api/push", s.PushHandler)
|
||||
r.POST("/api/copy", s.CopyHandler)
|
||||
|
@ -46,6 +46,14 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
|
||||
return
|
||||
}
|
||||
|
||||
func (mockRunner) Detokenize(_ context.Context, tokens []int) (string, error) {
|
||||
var strs []string
|
||||
for _, t := range tokens {
|
||||
strs = append(strs, fmt.Sprint(t))
|
||||
}
|
||||
return strings.Join(strs, " "), nil
|
||||
}
|
||||
|
||||
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
||||
return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
return mock, nil
|
||||
|
290
server/routes_tokenization_test.go
Normal file
290
server/routes_tokenization_test.go
Normal file
@ -0,0 +1,290 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestTokenize(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(context.TODO())
|
||||
|
||||
t.Run("missing body", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
|
||||
s.TokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
|
||||
s.TokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("tokenize text", func(t *testing.T) {
|
||||
// First create the model
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []llm.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})),
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create model: %d", w.Code)
|
||||
}
|
||||
|
||||
// Now test tokenization
|
||||
body, err := json.Marshal(api.TokenizeRequest{
|
||||
Model: "test",
|
||||
Text: "Hello world how are you",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
s.TokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp api.TokenizeResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Errorf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
// Our mock tokenizer creates sequential tokens based on word count
|
||||
expected := []int{0, 1, 2, 3, 4}
|
||||
if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tokenize empty text", func(t *testing.T) {
|
||||
body, err := json.Marshal(api.TokenizeRequest{
|
||||
Model: "test",
|
||||
Text: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
s.TokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetokenize(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: "stop",
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(context.TODO())
|
||||
|
||||
t.Run("detokenize tokens", func(t *testing.T) {
|
||||
// Create the model first
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []llm.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})),
|
||||
Stream: &stream,
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
body, err := json.Marshal(api.DetokenizeRequest{
|
||||
Model: "test",
|
||||
Tokens: []int{0, 1, 2, 3, 4},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
s.DetokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp api.DetokenizeResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Errorf("failed to decode response: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("detokenize empty tokens", func(t *testing.T) {
|
||||
body, err := json.Marshal(api.DetokenizeRequest{
|
||||
Model: "test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
s.DetokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("detokenize missing model", func(t *testing.T) {
|
||||
body, err := json.Marshal(api.DetokenizeRequest{
|
||||
Tokens: []int{0, 1, 2},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
s.DetokenizeHandler(w, r)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected status 404, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
136
server/routes_tokenize_test.go
Normal file
136
server/routes_tokenize_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type mockModelLoader struct {
|
||||
LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
|
||||
}
|
||||
|
||||
func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
|
||||
if ml.LoadModelFn != nil {
|
||||
return ml.LoadModelFn(name, params)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockModel struct {
|
||||
llama.Model
|
||||
TokenizeFn func(text string, addBos bool, addEos bool) ([]int, error)
|
||||
TokenToPieceFn func(token int) string
|
||||
}
|
||||
|
||||
func (mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
|
||||
return []int{1, 2, 3}, nil
|
||||
}
|
||||
|
||||
func (mockModel) TokenToPiece(token int) string {
|
||||
return fmt.Sprint(token)
|
||||
}
|
||||
|
||||
func TestTokenizeHandler(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockModel := mockModel{}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mockRunner{}),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mockRunner{},
|
||||
}
|
||||
},
|
||||
},
|
||||
ml: mockLoader,
|
||||
}
|
||||
|
||||
t.Run("method not allowed", func(t *testing.T) {
|
||||
w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing body", func(t *testing.T) {
|
||||
w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing text", func(t *testing.T) {
|
||||
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
|
||||
Model: "test",
|
||||
})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
|
||||
Text: "test text",
|
||||
})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model not found", func(t *testing.T) {
|
||||
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
|
||||
Model: "nonexistent",
|
||||
Text: "test text",
|
||||
})
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful tokenization", func(t *testing.T) {
|
||||
w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
|
||||
Model: "test",
|
||||
Text: "test text",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
var resp api.TokenizeResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedTokens := []int{0, 1}
|
||||
if len(resp.Tokens) != len(expectedTokens) {
|
||||
t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
|
||||
}
|
||||
for i, token := range resp.Tokens {
|
||||
if token != expectedTokens[i] {
|
||||
t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user