Add image data support to embedding requests (multimodal embeddings)

This commit is contained in:
Roy Han 2024-07-15 12:13:06 -07:00
parent 766ca1cd7d
commit eb7cc2d1ce
5 changed files with 58 additions and 28 deletions

View File

@ -187,6 +187,10 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"`
// Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}

View File

@ -3192,12 +3192,22 @@ int main(int argc, char **argv) {
prompt = prompt[0];
}
json image_data;
if (body.count("image_data") != 0)
{
image_data = body["image_data"];
}
else {
image_data = "";
}
// TODO: prompt needs to represent the image data
// create and queue the task
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
llama.request_completion(id_task, { {"prompt", prompt}, {"image_data", image_data} }, true, -1);
// get the result
task_result result = llama.queue_results.recv(id_task);

View File

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) ([][]float32, error)
Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@ -860,14 +860,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
type EmbedRequest struct {
Content []string `json:"content"`
Content []string `json:"content"`
Images []ImageData `json:"image_data"`
}
// EmbedResponse is the payload returned by the runner's embedding
// endpoint: one float32 vector per input in the request.
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@ -882,7 +883,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input})
data, err := json.Marshal(EmbedRequest{Content: input, Images: images})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

View File

@ -265,29 +265,38 @@ func (s *Server) EmbedHandler(c *gin.Context) {
truncate = false
}
var input []string
inputCheck := true
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
if req.Images != nil {
inputCheck = false
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
var input []string
if inputCheck {
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
@ -326,7 +335,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
embeddings, err := r.Embed(c.Request.Context(), input, images)
if err != nil {
slog.Error("embedding generation failed", "error", err)
@ -384,7 +399,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}, nil)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))

View File

@ -660,7 +660,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
// Completion returns the mock's preconfigured completionResp value;
// the streaming callback fn is never invoked and req is ignored.
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
func (s *mockLlm) Embed(ctx context.Context, input []string, images []llm.ImageData) ([][]float32, error) {
return s.embedResp, s.embedRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {