From 00a4cb26ca097b07d6aab3d043e61b8bf62e5341 Mon Sep 17 00:00:00 2001 From: Roy Han Date: Tue, 2 Jul 2024 10:30:29 -0700 Subject: [PATCH] use float32 --- api/types.go | 14 +++++++------- format/normalize.go | 8 ++++---- format/normalize_test.go | 16 ++++++++-------- llm/server.go | 6 +++--- server/routes.go | 6 +++--- server/sched_test.go | 4 ++-- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/api/types.go b/api/types.go index 44e5c49ae..f39c2818d 100644 --- a/api/types.go +++ b/api/types.go @@ -210,7 +210,7 @@ type EmbedRequest struct { Model string `json:"model"` // Input is the input to embed. - Input any `json:"input,omitempty"` + Input any `json:"input"` // KeepAlive controls how long the model will stay loaded in memory following // this request. @@ -222,6 +222,12 @@ type EmbedRequest struct { Options map[string]interface{} `json:"options"` } +// EmbedResponse is the response from [Client.Embed]. +type EmbedResponse struct { + Model string `json:"model"` + Embeddings [][]float32 `json:"embeddings,omitempty"` +} + // EmbeddingRequest is the request passed to [Client.Embeddings]. type EmbeddingRequest struct { // Model is the model name. @@ -238,12 +244,6 @@ type EmbeddingRequest struct { Options map[string]interface{} `json:"options"` } -// EmbedResponse is the response from [Client.Embed]. -type EmbedResponse struct { - Model string `json:"model"` - Embeddings [][]float64 `json:"embeddings,omitempty"` -} - // EmbeddingResponse is the response from [Client.Embeddings]. type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` diff --git a/format/normalize.go b/format/normalize.go index 15aa29b6a..67a42a933 100644 --- a/format/normalize.go +++ b/format/normalize.go @@ -2,18 +2,18 @@ package format import "math" -func Normalize(vec []float64) []float64 { +func Normalize(vec []float32) []float32 { var sum float64 for _, v := range vec { - sum += v * v + sum += float64(v * v) } sum = math.Sqrt(sum) - var norm float64 + var norm float32 if sum > 0 { - norm = 1.0 / sum + norm = float32(1.0 / sum) } else { norm = 0.0 } diff --git a/format/normalize_test.go b/format/normalize_test.go index fb18a1e6a..69e432ad3 100644 --- a/format/normalize_test.go +++ b/format/normalize_test.go @@ -7,21 +7,21 @@ import ( func TestNormalize(t *testing.T) { type testCase struct { - input []float64 + input []float32 } testCases := []testCase{ - {input: []float64{1}}, - {input: []float64{0, 1, 2, 3}}, - {input: []float64{0.1, 0.2, 0.3}}, - {input: []float64{-0.1, 0.2, 0.3, -0.4}}, - {input: []float64{0, 0, 0}}, + {input: []float32{1}}, + {input: []float32{0, 1, 2, 3}}, + {input: []float32{0.1, 0.2, 0.3}}, + {input: []float32{-0.1, 0.2, 0.3, -0.4}}, + {input: []float32{0, 0, 0}}, } - assertNorm := func(vec []float64) (res bool) { + assertNorm := func(vec []float32) (res bool) { sum := 0.0 for _, v := range vec { - sum += v * v + sum += float64(v * v) } if math.Abs(sum-1) > 1e-6 { return sum == 0 diff --git a/llm/server.go b/llm/server.go index 1d4ca460f..eec72ae49 100644 --- a/llm/server.go +++ b/llm/server.go @@ -34,7 +34,7 @@ type LlamaServer interface { WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Embedding(ctx context.Context, prompt string) ([]float64, error) - Embed(ctx context.Context, input []string) ([][]float64, error) + Embed(ctx context.Context, input []string) ([][]float32, error) Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error @@ -847,10 +847,10 @@ type EmbedRequest struct { } type EmbedResponse struct { - Embedding [][]float64 `json:"embedding"` + Embedding [][]float32 `json:"embedding"` } -func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) { +func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) { if err := s.sem.Acquire(ctx, 1); err != nil { slog.Error("Failed to acquire semaphore", "error", err) return nil, err diff --git a/server/routes.go b/server/routes.go index bd32c5faa..d493a8982 100644 --- a/server/routes.go +++ b/server/routes.go @@ -414,12 +414,12 @@ func (s *Server) EmbedHandler(c *gin.Context) { return s, nil } - embeddings := [][]float64{} + embeddings := [][]float32{} switch reqEmbed := req.Input.(type) { case string: if reqEmbed == "" { - c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}}) + c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}}) return } reqEmbed, err = checkFit(reqEmbed, *req.Truncate) @@ -430,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) case []any: if reqEmbed == nil { - c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}}) + c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}}) return } diff --git a/server/sched_test.go b/server/sched_test.go index 1038b68dc..948347fee 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -610,7 +610,7 @@ type mockLlm struct { completionResp error embeddingResp []float64 embeddingRespErr error - embedResp [][]float64 + embedResp [][]float32 embedRespErr error tokenizeResp []int tokenizeRespErr error @@ -631,7 +631,7 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { return s.embeddingResp, s.embeddingRespErr } -func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) { +func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) { return s.embedResp, s.embedRespErr } func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {