use float32

This commit is contained in:
Roy Han 2024-07-02 10:30:29 -07:00
parent 512e0a7bde
commit 00a4cb26ca
6 changed files with 27 additions and 27 deletions

View File

@ -210,7 +210,7 @@ type EmbedRequest struct {
Model string `json:"model"` Model string `json:"model"`
// Input is the input to embed. // Input is the input to embed.
Input any `json:"input,omitempty"` Input any `json:"input"`
// KeepAlive controls how long the model will stay loaded in memory following // KeepAlive controls how long the model will stay loaded in memory following
// this request. // this request.
@ -222,6 +222,12 @@ type EmbedRequest struct {
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
// Model is serialized under the "model" JSON key.
// NOTE(review): presumably the name of the model that produced the
// embeddings — confirm against the handler that fills it in.
Model string `json:"model"`
// Embeddings is a list of float32 vectors serialized under the
// "embeddings" JSON key. Because of omitempty, the key is left out of the
// encoded JSON whenever the slice has length zero, including a non-nil
// empty slice such as [][]float32{}.
// NOTE(review): presumably one vector per input item — confirm.
Embeddings [][]float32 `json:"embeddings,omitempty"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings]. // EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct { type EmbeddingRequest struct {
// Model is the model name. // Model is the model name.
@ -238,12 +244,6 @@ type EmbeddingRequest struct {
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings,omitempty"`
}
// EmbeddingResponse is the response from [Client.Embeddings]. // EmbeddingResponse is the response from [Client.Embeddings].
type EmbeddingResponse struct { type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`

View File

@ -2,18 +2,18 @@ package format
import "math" import "math"
func Normalize(vec []float64) []float64 { func Normalize(vec []float32) []float32 {
var sum float64 var sum float64
for _, v := range vec { for _, v := range vec {
sum += v * v sum += float64(v * v)
} }
sum = math.Sqrt(sum) sum = math.Sqrt(sum)
var norm float64 var norm float32
if sum > 0 { if sum > 0 {
norm = 1.0 / sum norm = float32(1.0 / sum)
} else { } else {
norm = 0.0 norm = 0.0
} }

View File

@ -7,21 +7,21 @@ import (
func TestNormalize(t *testing.T) { func TestNormalize(t *testing.T) {
type testCase struct { type testCase struct {
input []float64 input []float32
} }
testCases := []testCase{ testCases := []testCase{
{input: []float64{1}}, {input: []float32{1}},
{input: []float64{0, 1, 2, 3}}, {input: []float32{0, 1, 2, 3}},
{input: []float64{0.1, 0.2, 0.3}}, {input: []float32{0.1, 0.2, 0.3}},
{input: []float64{-0.1, 0.2, 0.3, -0.4}}, {input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float64{0, 0, 0}}, {input: []float32{0, 0, 0}},
} }
assertNorm := func(vec []float64) (res bool) { assertNorm := func(vec []float32) (res bool) {
sum := 0.0 sum := 0.0
for _, v := range vec { for _, v := range vec {
sum += v * v sum += float64(v * v)
} }
if math.Abs(sum-1) > 1e-6 { if math.Abs(sum-1) > 1e-6 {
return sum == 0 return sum == 0

View File

@ -34,7 +34,7 @@ type LlamaServer interface {
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error) Embedding(ctx context.Context, prompt string) ([]float64, error)
Embed(ctx context.Context, input []string) ([][]float64, error) Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@ -847,10 +847,10 @@ type EmbedRequest struct {
} }
type EmbedResponse struct { type EmbedResponse struct {
Embedding [][]float64 `json:"embedding"` Embedding [][]float32 `json:"embedding"`
} }
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) { func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err

View File

@ -414,12 +414,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return s, nil return s, nil
} }
embeddings := [][]float64{} embeddings := [][]float32{}
switch reqEmbed := req.Input.(type) { switch reqEmbed := req.Input.(type) {
case string: case string:
if reqEmbed == "" { if reqEmbed == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}}) c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
return return
} }
reqEmbed, err = checkFit(reqEmbed, *req.Truncate) reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
@ -430,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
case []any: case []any:
if reqEmbed == nil { if reqEmbed == nil {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}}) c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
return return
} }

View File

@ -610,7 +610,7 @@ type mockLlm struct {
completionResp error completionResp error
embeddingResp []float64 embeddingResp []float64
embeddingRespErr error embeddingRespErr error
embedResp [][]float64 embedResp [][]float32
embedRespErr error embedRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
@ -631,7 +631,7 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr return s.embeddingResp, s.embeddingRespErr
} }
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) { func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embedResp, s.embedRespErr return s.embedResp, s.embedRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {