Compare commits
11 Commits: main...parth/toke

Commits in this comparison:
0ef4db0e24
e3dd90102d
11acb85ff3
1e545ea7a0
e679885733
f0a5f7994b
da35ad878b
a5e66a1163
6fad1637ed
24613df094
e60db349b7
@@ -360,6 +360,24 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
	return &resp, nil
}

// Tokenize returns the tokens for a given text.
func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) {
	var resp TokenizeResponse
	if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}

// Detokenize returns the text for a given list of tokens.
func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) {
	var resp DetokenizeResponse
	if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}

// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
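For context, a minimal sketch of how a caller could exercise the new client methods, assuming the existing `api.ClientFromEnvironment` constructor and the request/response types added to `api/types.go` below; the model name is only a placeholder.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Assumes a local Ollama server (OLLAMA_HOST or the default address).
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// Round-trip: text -> tokens -> text via the new endpoints.
	tok, err := client.Tokenize(ctx, &api.TokenizeRequest{Model: "llama3.2", Text: "Why is the sky blue?"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("tokens:", tok.Tokens)

	detok, err := client.Detokenize(ctx, &api.DetokenizeRequest{Model: "llama3.2", Tokens: tok.Tokens})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("text:", detok.Text)
}
```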
api/types.go (+22)

@@ -293,6 +293,28 @@ type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

// TokenizeRequest is the request sent by [Client.Tokenize].
type TokenizeRequest struct {
	Model string `json:"model"`
	Text  string `json:"text"`
}

// TokenizeResponse is the response from [Client.Tokenize].
type TokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

// DetokenizeRequest is the request sent by [Client.Detokenize].
type DetokenizeRequest struct {
	Model  string `json:"model"`
	Tokens []int  `json:"tokens"`
}

// DetokenizeResponse is the response from [Client.Detokenize].
type DetokenizeResponse struct {
	Text string `json:"text"`
}

// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
	Model string `json:"model"`
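As a quick illustration of the wire format these types produce (not part of the change itself), the sketch below marshals the new request structs; the field names come straight from the struct tags above, and the token values are the ones used in the docs examples.

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// TokenizeRequest serializes with the "model" and "text" keys from its struct tags.
	b, _ := json.Marshal(api.TokenizeRequest{Model: "llama3.2", Text: "Why is the sky blue?"})
	fmt.Println(string(b)) // {"model":"llama3.2","text":"Why is the sky blue?"}

	// DetokenizeRequest carries the token IDs to convert back into text.
	b, _ = json.Marshal(api.DetokenizeRequest{Model: "llama3.2", Tokens: []int{10445, 279, 13180, 374, 6437, 30}})
	fmt.Println(string(b)) // {"model":"llama3.2","tokens":[10445,279,13180,374,6437,30]}
}
```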
docs/api.md (+65)

@@ -13,6 +13,8 @@
- [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
- [Tokenize Text](#tokenize-text)
- [Detokenize Tokens](#detokenize-tokens)

## Conventions

@@ -1485,6 +1487,69 @@ A single JSON object will be returned.
}
```

## Tokenize Text

Tokenize text into an array of tokens using a specific model.

```shell
POST /api/tokenize
```

### Parameters

- `model`: name of model to use for tokenization
- `text`: text to tokenize

### Examples

#### Request

```shell
curl -X POST http://localhost:11434/api/tokenize -d '{
  "model": "llama3.2",
  "text": "Why is the sky blue?"
}'
```

#### Response

```json
{
  "tokens": [10445,279,13180,374,6437,30]
}
```

## Detokenize Tokens

Detokenize tokens into text using a specific model.

```shell
POST /api/detokenize
```

### Parameters

- `model`: name of model to use for detokenization
- `tokens`: list of tokens to detokenize

### Examples

#### Request

```shell
curl -X POST http://localhost:11434/api/detokenize -d '{
  "model": "llama3.2",
  "tokens": [10445,279,13180,374,6437,30]
}'
```

#### Response

```json
{"text":"Why is the sky blue?"}
```

## Generate Embedding

> Note: this endpoint has been superseded by `/api/embed`
@@ -449,9 +449,24 @@ type Model struct {
	c *C.struct_llama_model
}

func (m *Model) Detokenize(tokens []int) (string, error) {
	var text string
	for _, token := range tokens {
		piece := m.TokenToPiece(token)
		if piece == "" {
			return "", fmt.Errorf("failed to convert token %d to piece", token)
		}
		text += piece
	}
	return text, nil
}

func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	if token > m.NumVocab() {
		return ""
	}
	tokenLen = int(C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
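The `Detokenize` helper above simply concatenates per-token pieces and treats an empty piece as an out-of-range token. Below is a small, self-contained illustration of that strategy with a toy piece table standing in for the cgo-backed `TokenToPiece`; it is a sketch, not the actual implementation.

```go
package main

import (
	"fmt"
	"strings"
)

// pieceTable is a toy vocabulary mapping token IDs to text pieces.
var pieceTable = map[int]string{0: "Why", 1: " is", 2: " the", 3: " sky", 4: " blue", 5: "?"}

// detokenize mirrors the concatenate-pieces strategy: an empty or missing piece
// means the token is unknown (e.g. out of vocabulary), so the whole call fails.
func detokenize(tokens []int) (string, error) {
	var sb strings.Builder
	for _, token := range tokens {
		piece, ok := pieceTable[token]
		if !ok || piece == "" {
			return "", fmt.Errorf("failed to convert token %d to piece", token)
		}
		sb.WriteString(piece)
	}
	return sb.String(), nil
}

func main() {
	text, err := detokenize([]int{0, 1, 2, 3, 4, 5})
	if err != nil {
		panic(err)
	}
	fmt.Println(text) // Why is the sky blue?
}
```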
server/model_loader.go (new file, +67)

@@ -0,0 +1,67 @@
package server

import (
	"fmt"
	"sync"

	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/types/model"
)

type loadedModel struct {
	model     llama.Model
	modelPath string
}

type modelLoader struct {
	cache sync.Map
}

// modelCache stores loaded models keyed by their full path and params hash
var modelCache sync.Map // map[string]*loadedModel

func (ml *modelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
	modelName := model.ParseName(name)
	if !modelName.IsValid() {
		return nil, fmt.Errorf("invalid model name: %s", modelName)
	}

	modelPath, err := GetModel(modelName.String())
	if err != nil {
		return nil, fmt.Errorf("model not found: %s", modelName)
	}

	// Create cache key from model path and params hash
	cacheKey := fmt.Sprintf("%s-%+v", modelPath.ModelPath, params)
	if cached, ok := modelCache.Load(cacheKey); ok {
		return cached.(*loadedModel), nil
	}

	// Evict existing model if any
	ml.evictExistingModel()

	model, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
	if err != nil {
		return nil, fmt.Errorf("failed to load model: %v", err)
	}

	loaded := &loadedModel{
		model:     *model,
		modelPath: modelPath.ModelPath,
	}
	modelCache.Store(cacheKey, loaded)

	return loaded, nil
}

// evictExistingModel removes any currently loaded model from the cache.
// Currently only supports a single model in cache at a time.
// TODO: Add proper cache eviction policy (LRU/size/TTL based)
func (ml *modelLoader) evictExistingModel() {
	ml.cache.Range(func(key, value any) bool {
		if cached, ok := ml.cache.LoadAndDelete(key); ok {
			llama.FreeModel(&cached.(*loadedModel).model)
		}
		return true
	})
}
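The loader keys its cache on the resolved model path plus a `%+v` dump of the load params, and `evictExistingModel` clears the cache before each new load, so at most one model stays resident at a time. Below is a tiny, hypothetical sketch of the cache-key scheme, using a stand-in `Params` type rather than the real `llama.ModelParams`.

```go
package main

import "fmt"

// Params is a stand-in for llama.ModelParams, used only to illustrate the key format.
type Params struct {
	VocabOnly bool
	NumGPU    int
}

func main() {
	// The cache key is the model path combined with a %+v dump of the params,
	// so the same path loaded with different params gets a distinct cache entry.
	a := fmt.Sprintf("%s-%+v", "/models/llama3.2.gguf", Params{VocabOnly: true})
	b := fmt.Sprintf("%s-%+v", "/models/llama3.2.gguf", Params{VocabOnly: false})
	fmt.Println(a) // /models/llama3.2.gguf-{VocabOnly:true NumGPU:0}
	fmt.Println(b) // /models/llama3.2.gguf-{VocabOnly:false NumGPU:0}
}
```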
server/routes.go (+103)

@@ -30,6 +30,7 @@ import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/openai"
	"github.com/ollama/ollama/parser"

@@ -46,6 +47,7 @@ var mode string = gin.DebugMode
type Server struct {
	addr  net.Addr
	sched *Scheduler
	ml    modelLoader
}

func init() {

@@ -548,6 +550,105 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
	c.JSON(http.StatusOK, resp)
}

func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req api.TokenizeRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		if errors.Is(err, io.EOF) {
			http.Error(w, "missing request body", http.StatusBadRequest)
			return
		}
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if req.Text == "" {
		http.Error(w, "missing `text` for tokenization", http.StatusBadRequest)
		return
	}

	if req.Model == "" {
		http.Error(w, "missing `model` for tokenization", http.StatusBadRequest)
		return
	}

	loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
		VocabOnly: true,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
		return
	}

	// Tokenize the text
	tokens, err := loadedModel.model.Tokenize(req.Text, false, true)
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to tokenize text: %v", err), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(api.TokenizeResponse{
		Tokens: tokens,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req api.DetokenizeRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		if errors.Is(err, io.EOF) {
			http.Error(w, "missing request body", http.StatusBadRequest)
			return
		}
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if req.Tokens == nil {
		http.Error(w, "missing tokens for detokenization", http.StatusBadRequest)
		return
	}

	if req.Model == "" {
		http.Error(w, "missing `model` for detokenization", http.StatusBadRequest)
		return
	}

	loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
		VocabOnly: true,
	})
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
		return
	}

	text, err := loadedModel.model.Detokenize(req.Tokens)
	if err != nil {
		http.Error(w, fmt.Sprintf("failed to detokenize text: %v", err), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(api.DetokenizeResponse{
		Text: text,
	}); err != nil {
		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
		return
	}
}

func (s *Server) PullHandler(c *gin.Context) {
	var req api.PullRequest
	err := c.ShouldBindJSON(&req)

@@ -1214,6 +1315,8 @@ func (s *Server) GenerateRoutes() http.Handler {
	r.POST("/api/chat", s.ChatHandler)
	r.POST("/api/embed", s.EmbedHandler)
	r.POST("/api/embeddings", s.EmbeddingsHandler)
	r.Any("/api/tokenize", gin.WrapF(s.TokenizeHandler))
	r.Any("/api/detokenize", gin.WrapF(s.DetokenizeHandler))
	r.POST("/api/create", s.CreateHandler)
	r.POST("/api/push", s.PushHandler)
	r.POST("/api/copy", s.CopyHandler)
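Note that the two new handlers are plain `net/http` handlers rather than gin handlers, which is why they are registered with `r.Any` plus `gin.WrapF` and enforce the POST method themselves. A minimal, hedged sketch of that wiring pattern outside this codebase:

```go
package main

import (
	"fmt"
	"net/http"

	"github.com/gin-gonic/gin"
)

// echoHandler is a plain net/http handler; it enforces POST itself because the
// gin route below is registered with r.Any rather than r.POST.
func echoHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	fmt.Fprintln(w, "ok")
}

func main() {
	r := gin.Default()
	// gin.WrapF adapts an http.HandlerFunc to a gin.HandlerFunc.
	r.Any("/echo", gin.WrapF(echoHandler))
	r.Run(":8080")
}
```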
@@ -46,6 +46,14 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
	return
}

func (mockRunner) Detokenize(_ context.Context, tokens []int) (string, error) {
	var strs []string
	for _, t := range tokens {
		strs = append(strs, fmt.Sprint(t))
	}
	return strings.Join(strs, " "), nil
}

func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
	return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		return mock, nil
server/routes_tokenization_test.go (new file, +290)

@@ -0,0 +1,290 @@
package server

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/go-cmp/cmp"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llm"
)

func TestTokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	t.Run("missing body", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", nil)
		s.TokenizeHandler(w, r)

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), "missing request body\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", strings.NewReader("{}"))
		s.TokenizeHandler(w, r)

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("tokenize text", func(t *testing.T) {
		// First create the model
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "test",
			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
				"general.architecture":          "llama",
				"llama.block_count":             uint32(1),
				"llama.context_length":          uint32(8192),
				"llama.embedding_length":        uint32(4096),
				"llama.attention.head_count":    uint32(32),
				"llama.attention.head_count_kv": uint32(8),
				"tokenizer.ggml.tokens":         []string{""},
				"tokenizer.ggml.scores":         []float32{0},
				"tokenizer.ggml.token_type":     []int32{0},
			}, []llm.Tensor{
				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			})),
		})
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create model: %d", w.Code)
		}

		// Now test tokenization
		body, err := json.Marshal(api.TokenizeRequest{
			Model: "test",
			Text:  "Hello world how are you",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w = httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.TokenizeHandler(w, r)

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
		}

		var resp api.TokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Errorf("failed to decode response: %v", err)
		}

		// Our mock tokenizer creates sequential tokens based on word count
		expected := []int{0, 1, 2, 3, 4}
		if diff := cmp.Diff(resp.Tokens, expected); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("tokenize empty text", func(t *testing.T) {
		body, err := json.Marshal(api.TokenizeRequest{
			Model: "test",
			Text:  "",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.TokenizeHandler(w, r)

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), "missing `text` for tokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}

func TestDetokenize(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         "stop",
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
			},
		},
	}

	go s.sched.Run(context.TODO())

	t.Run("detokenize tokens", func(t *testing.T) {
		// Create the model first
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "test",
			Modelfile: fmt.Sprintf(`FROM %s`, createBinFile(t, llm.KV{
				"general.architecture":          "llama",
				"llama.block_count":             uint32(1),
				"llama.context_length":          uint32(8192),
				"llama.embedding_length":        uint32(4096),
				"llama.attention.head_count":    uint32(32),
				"llama.attention.head_count_kv": uint32(8),
				"tokenizer.ggml.tokens":         []string{""},
				"tokenizer.ggml.scores":         []float32{0},
				"tokenizer.ggml.token_type":     []int32{0},
			}, []llm.Tensor{
				{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
				{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
			})),
			Stream: &stream,
		})
		if w.Code != http.StatusOK {
			t.Fatalf("failed to create model: %d - %s", w.Code, w.Body.String())
		}

		body, err := json.Marshal(api.DetokenizeRequest{
			Model:  "test",
			Tokens: []int{0, 1, 2, 3, 4},
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w = httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.DetokenizeHandler(w, r)

		if w.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d: %s", w.Code, w.Body.String())
		}

		var resp api.DetokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Errorf("failed to decode response: %v", err)
		}
	})

	t.Run("detokenize empty tokens", func(t *testing.T) {
		body, err := json.Marshal(api.DetokenizeRequest{
			Model: "test",
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.DetokenizeHandler(w, r)

		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status 400, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), "missing tokens for detokenization\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	t.Run("detokenize missing model", func(t *testing.T) {
		body, err := json.Marshal(api.DetokenizeRequest{
			Tokens: []int{0, 1, 2},
		})
		if err != nil {
			t.Fatalf("failed to marshal request: %v", err)
		}

		w := httptest.NewRecorder()
		r := httptest.NewRequest(http.MethodPost, "/api/detokenize", bytes.NewReader(body))
		r.Header.Set("Content-Type", "application/json")
		s.DetokenizeHandler(w, r)

		if w.Code != http.StatusNotFound {
			t.Errorf("expected status 404, got %d", w.Code)
		}

		if diff := cmp.Diff(w.Body.String(), "model '' not found\n"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}
server/routes_tokenize_test.go (new file, +136)

@@ -0,0 +1,136 @@
package server

import (
	"encoding/json"
	"fmt"
	"net/http"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
)

type mockModelLoader struct {
	LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
}

func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
	if ml.LoadModelFn != nil {
		return ml.LoadModelFn(name, params)
	}

	return nil, nil
}

type mockModel struct {
	llama.Model
	TokenizeFn     func(text string, addBos bool, addEos bool) ([]int, error)
	TokenToPieceFn func(token int) string
}

func (mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
	return []int{1, 2, 3}, nil
}

func (mockModel) TokenToPiece(token int) string {
	return fmt.Sprint(token)
}

func TestTokenizeHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mockModel := mockModel{}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mockRunner{}),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mockRunner{},
				}
			},
		},
		ml: mockLoader,
	}

	t.Run("method not allowed", func(t *testing.T) {
		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
		if w.Code != http.StatusMethodNotAllowed {
			t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
		}
	})

	t.Run("missing body", func(t *testing.T) {
		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
		}
	})

	t.Run("missing text", func(t *testing.T) {
		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
			Model: "test",
		})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
		}
	})

	t.Run("missing model", func(t *testing.T) {
		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
			Text: "test text",
		})
		if w.Code != http.StatusBadRequest {
			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
		}
	})

	t.Run("model not found", func(t *testing.T) {
		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
			Model: "nonexistent",
			Text:  "test text",
		})
		if w.Code != http.StatusInternalServerError {
			t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
		}
	})

	t.Run("successful tokenization", func(t *testing.T) {
		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
			Model: "test",
			Text:  "test text",
		})

		if w.Code != http.StatusOK {
			t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
		}

		var resp api.TokenizeResponse
		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
			t.Fatal(err)
		}

		expectedTokens := []int{0, 1}
		if len(resp.Tokens) != len(expectedTokens) {
			t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
		}
		for i, token := range resp.Tokens {
			if token != expectedTokens[i] {
				t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
			}
		}
	})
}