use float32
This commit is contained in:
parent
512e0a7bde
commit
00a4cb26ca
14
api/types.go
14
api/types.go
@ -210,7 +210,7 @@ type EmbedRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|
||||||
// Input is the input to embed.
|
// Input is the input to embed.
|
||||||
Input any `json:"input,omitempty"`
|
Input any `json:"input"`
|
||||||
|
|
||||||
// KeepAlive controls how long the model will stay loaded in memory following
|
// KeepAlive controls how long the model will stay loaded in memory following
|
||||||
// this request.
|
// this request.
|
||||||
@ -222,6 +222,12 @@ type EmbedRequest struct {
|
|||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbedResponse is the response from [Client.Embed].
|
||||||
|
type EmbedResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Embeddings [][]float32 `json:"embeddings,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
// Model is the model name.
|
// Model is the model name.
|
||||||
@ -238,12 +244,6 @@ type EmbeddingRequest struct {
|
|||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbedResponse is the response from [Client.Embed].
|
|
||||||
type EmbedResponse struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Embeddings [][]float64 `json:"embeddings,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddingResponse is the response from [Client.Embeddings].
|
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
|
@ -2,18 +2,18 @@ package format
|
|||||||
|
|
||||||
import "math"
|
import "math"
|
||||||
|
|
||||||
func Normalize(vec []float64) []float64 {
|
func Normalize(vec []float32) []float32 {
|
||||||
var sum float64
|
var sum float64
|
||||||
for _, v := range vec {
|
for _, v := range vec {
|
||||||
sum += v * v
|
sum += float64(v * v)
|
||||||
}
|
}
|
||||||
|
|
||||||
sum = math.Sqrt(sum)
|
sum = math.Sqrt(sum)
|
||||||
|
|
||||||
var norm float64
|
var norm float32
|
||||||
|
|
||||||
if sum > 0 {
|
if sum > 0 {
|
||||||
norm = 1.0 / sum
|
norm = float32(1.0 / sum)
|
||||||
} else {
|
} else {
|
||||||
norm = 0.0
|
norm = 0.0
|
||||||
}
|
}
|
||||||
|
@ -7,21 +7,21 @@ import (
|
|||||||
|
|
||||||
func TestNormalize(t *testing.T) {
|
func TestNormalize(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
input []float64
|
input []float32
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{input: []float64{1}},
|
{input: []float32{1}},
|
||||||
{input: []float64{0, 1, 2, 3}},
|
{input: []float32{0, 1, 2, 3}},
|
||||||
{input: []float64{0.1, 0.2, 0.3}},
|
{input: []float32{0.1, 0.2, 0.3}},
|
||||||
{input: []float64{-0.1, 0.2, 0.3, -0.4}},
|
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||||
{input: []float64{0, 0, 0}},
|
{input: []float32{0, 0, 0}},
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNorm := func(vec []float64) (res bool) {
|
assertNorm := func(vec []float32) (res bool) {
|
||||||
sum := 0.0
|
sum := 0.0
|
||||||
for _, v := range vec {
|
for _, v := range vec {
|
||||||
sum += v * v
|
sum += float64(v * v)
|
||||||
}
|
}
|
||||||
if math.Abs(sum-1) > 1e-6 {
|
if math.Abs(sum-1) > 1e-6 {
|
||||||
return sum == 0
|
return sum == 0
|
||||||
|
@ -34,7 +34,7 @@ type LlamaServer interface {
|
|||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
||||||
Embed(ctx context.Context, input []string) ([][]float64, error)
|
Embed(ctx context.Context, input []string) ([][]float32, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
@ -847,10 +847,10 @@ type EmbedRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbedResponse struct {
|
type EmbedResponse struct {
|
||||||
Embedding [][]float64 `json:"embedding"`
|
Embedding [][]float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) {
|
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -414,12 +414,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings := [][]float64{}
|
embeddings := [][]float32{}
|
||||||
|
|
||||||
switch reqEmbed := req.Input.(type) {
|
switch reqEmbed := req.Input.(type) {
|
||||||
case string:
|
case string:
|
||||||
if reqEmbed == "" {
|
if reqEmbed == "" {
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
||||||
@ -430,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
||||||
case []any:
|
case []any:
|
||||||
if reqEmbed == nil {
|
if reqEmbed == nil {
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -610,7 +610,7 @@ type mockLlm struct {
|
|||||||
completionResp error
|
completionResp error
|
||||||
embeddingResp []float64
|
embeddingResp []float64
|
||||||
embeddingRespErr error
|
embeddingRespErr error
|
||||||
embedResp [][]float64
|
embedResp [][]float32
|
||||||
embedRespErr error
|
embedRespErr error
|
||||||
tokenizeResp []int
|
tokenizeResp []int
|
||||||
tokenizeRespErr error
|
tokenizeRespErr error
|
||||||
@ -631,7 +631,7 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
|
|||||||
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||||
return s.embeddingResp, s.embeddingRespErr
|
return s.embeddingResp, s.embeddingRespErr
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) {
|
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
|
||||||
return s.embedResp, s.embedRespErr
|
return s.embedResp, s.embedRespErr
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user