use float32
This commit is contained in:
parent
512e0a7bde
commit
00a4cb26ca
14
api/types.go
14
api/types.go
@@ -210,7 +210,7 @@ type EmbedRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
||||
// Input is the input to embed.
|
||||
Input any `json:"input,omitempty"`
|
||||
Input any `json:"input"`
|
||||
|
||||
// KeepAlive controls how long the model will stay loaded in memory following
|
||||
// this request.
|
||||
@@ -222,6 +222,12 @@ type EmbedRequest struct {
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
// EmbedResponse is the response from [Client.Embed].
|
||||
type EmbedResponse struct {
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float32 `json:"embeddings,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||
type EmbeddingRequest struct {
|
||||
// Model is the model name.
|
||||
@@ -238,12 +244,6 @@ type EmbeddingRequest struct {
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
// EmbedResponse is the response from [Client.Embed].
|
||||
type EmbedResponse struct {
|
||||
Model string `json:"model"`
|
||||
Embeddings [][]float64 `json:"embeddings,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
|
@@ -2,18 +2,18 @@ package format
|
||||
|
||||
import "math"
|
||||
|
||||
func Normalize(vec []float64) []float64 {
|
||||
func Normalize(vec []float32) []float32 {
|
||||
var sum float64
|
||||
for _, v := range vec {
|
||||
sum += v * v
|
||||
sum += float64(v * v)
|
||||
}
|
||||
|
||||
sum = math.Sqrt(sum)
|
||||
|
||||
var norm float64
|
||||
var norm float32
|
||||
|
||||
if sum > 0 {
|
||||
norm = 1.0 / sum
|
||||
norm = float32(1.0 / sum)
|
||||
} else {
|
||||
norm = 0.0
|
||||
}
|
||||
|
@@ -7,21 +7,21 @@ import (
|
||||
|
||||
func TestNormalize(t *testing.T) {
|
||||
type testCase struct {
|
||||
input []float64
|
||||
input []float32
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{input: []float64{1}},
|
||||
{input: []float64{0, 1, 2, 3}},
|
||||
{input: []float64{0.1, 0.2, 0.3}},
|
||||
{input: []float64{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float64{0, 0, 0}},
|
||||
{input: []float32{1}},
|
||||
{input: []float32{0, 1, 2, 3}},
|
||||
{input: []float32{0.1, 0.2, 0.3}},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float32{0, 0, 0}},
|
||||
}
|
||||
|
||||
assertNorm := func(vec []float64) (res bool) {
|
||||
assertNorm := func(vec []float32) (res bool) {
|
||||
sum := 0.0
|
||||
for _, v := range vec {
|
||||
sum += v * v
|
||||
sum += float64(v * v)
|
||||
}
|
||||
if math.Abs(sum-1) > 1e-6 {
|
||||
return sum == 0
|
||||
|
@@ -34,7 +34,7 @@ type LlamaServer interface {
|
||||
WaitUntilRunning(ctx context.Context) error
|
||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
||||
Embed(ctx context.Context, input []string) ([][]float64, error)
|
||||
Embed(ctx context.Context, input []string) ([][]float32, error)
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
@@ -847,10 +847,10 @@ type EmbedRequest struct {
|
||||
}
|
||||
|
||||
type EmbedResponse struct {
|
||||
Embedding [][]float64 `json:"embedding"`
|
||||
Embedding [][]float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) {
|
||||
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
return nil, err
|
||||
|
@@ -414,12 +414,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
embeddings := [][]float64{}
|
||||
embeddings := [][]float32{}
|
||||
|
||||
switch reqEmbed := req.Input.(type) {
|
||||
case string:
|
||||
if reqEmbed == "" {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
||||
@@ -430,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
||||
case []any:
|
||||
if reqEmbed == nil {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -610,7 +610,7 @@ type mockLlm struct {
|
||||
completionResp error
|
||||
embeddingResp []float64
|
||||
embeddingRespErr error
|
||||
embedResp [][]float64
|
||||
embedResp [][]float32
|
||||
embedRespErr error
|
||||
tokenizeResp []int
|
||||
tokenizeRespErr error
|
||||
@@ -631,7 +631,7 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
|
||||
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||
return s.embeddingResp, s.embeddingRespErr
|
||||
}
|
||||
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) {
|
||||
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||
return s.embedResp, s.embedRespErr
|
||||
}
|
||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user