image embeddings

This commit is contained in:
Roy Han 2024-07-15 12:13:06 -07:00
parent 766ca1cd7d
commit eb7cc2d1ce
5 changed files with 58 additions and 28 deletions

View File

@ -187,6 +187,10 @@ type EmbedRequest struct {
Truncate *bool `json:"truncate,omitempty"`
// Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}

View File

@ -3192,12 +3192,22 @@ int main(int argc, char **argv) {
prompt = prompt[0];
}
json image_data;
if (body.count("image_data") != 0)
{
image_data = body["image_data"];
}
else {
image_data = "";
}
// TODO: prompt needs to represent the image data
// create and queue the task // create and queue the task
json responses; json responses;
{ {
const int id_task = llama.queue_tasks.get_new_id(); const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task); llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1); llama.request_completion(id_task, { {"prompt", prompt}, {"image_data", image_data} }, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(id_task); task_result result = llama.queue_results.recv(id_task);

View File

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@ -861,13 +861,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
type EmbedRequest struct {
Content []string `json:"content"`
Images []ImageData `json:"image_data"`
}
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string, images []ImageData) ([][]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@ -882,7 +883,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input, Images: images})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}

View File

@ -265,8 +265,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
truncate = false
}
inputCheck := true
if req.Images != nil {
inputCheck = false
}
var input []string
if inputCheck {
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
@ -289,6 +297,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
@ -326,7 +335,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
embeddings, err := r.Embed(c.Request.Context(), input, images)
if err != nil {
slog.Error("embedding generation failed", "error", err)
@ -384,7 +399,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}, nil)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))

View File

@ -660,7 +660,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embed(ctx context.Context, input []string, images []llm.ImageData) ([][]float32, error) {
return s.embedResp, s.embedRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {