use float32

This commit is contained in:
Roy Han 2024-07-02 10:30:29 -07:00
parent 512e0a7bde
commit 00a4cb26ca
6 changed files with 27 additions and 27 deletions

View File

@ -210,7 +210,7 @@ type EmbedRequest struct {
Model string `json:"model"`
// Input is the input to embed.
Input any `json:"input,omitempty"`
Input any `json:"input"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
@ -222,6 +222,12 @@ type EmbedRequest struct {
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings,omitempty"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct {
// Model is the model name.
@ -238,12 +244,6 @@ type EmbeddingRequest struct {
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float64 `json:"embeddings,omitempty"`
}
// EmbeddingResponse is the response from [Client.Embeddings].
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`

View File

@ -2,18 +2,18 @@ package format
import "math"
func Normalize(vec []float64) []float64 {
func Normalize(vec []float32) []float32 {
var sum float64
for _, v := range vec {
sum += v * v
sum += float64(v * v)
}
sum = math.Sqrt(sum)
var norm float64
var norm float32
if sum > 0 {
norm = 1.0 / sum
norm = float32(1.0 / sum)
} else {
norm = 0.0
}

View File

@ -7,21 +7,21 @@ import (
func TestNormalize(t *testing.T) {
type testCase struct {
input []float64
input []float32
}
testCases := []testCase{
{input: []float64{1}},
{input: []float64{0, 1, 2, 3}},
{input: []float64{0.1, 0.2, 0.3}},
{input: []float64{-0.1, 0.2, 0.3, -0.4}},
{input: []float64{0, 0, 0}},
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
}
assertNorm := func(vec []float64) (res bool) {
assertNorm := func(vec []float32) (res bool) {
sum := 0.0
for _, v := range vec {
sum += v * v
sum += float64(v * v)
}
if math.Abs(sum-1) > 1e-6 {
return sum == 0

View File

@ -34,7 +34,7 @@ type LlamaServer interface {
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Embed(ctx context.Context, input []string) ([][]float64, error)
Embed(ctx context.Context, input []string) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@ -847,10 +847,10 @@ type EmbedRequest struct {
}
type EmbedResponse struct {
Embedding [][]float64 `json:"embedding"`
Embedding [][]float32 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) {
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err

View File

@ -414,12 +414,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return s, nil
}
embeddings := [][]float64{}
embeddings := [][]float32{}
switch reqEmbed := req.Input.(type) {
case string:
if reqEmbed == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
return
}
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
@ -430,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
case []any:
if reqEmbed == nil {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
return
}

View File

@ -610,7 +610,7 @@ type mockLlm struct {
completionResp error
embeddingResp []float64
embeddingRespErr error
embedResp [][]float64
embedResp [][]float32
embedRespErr error
tokenizeResp []int
tokenizeRespErr error
@ -631,7 +631,7 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) {
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
return s.embedResp, s.embedRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {