Roy Han 2024-06-28 15:26:58 -07:00
parent b9c74df37b
commit 5213c12354
3 changed files with 7 additions and 10 deletions

View File

@@ -226,10 +226,7 @@ type EmbeddingRequest struct {
 	Model string `json:"model"`
 
 	// Prompt is the textual prompt to embed.
-	Prompt string `json:"prompt,omitempty"`
-
-	// PromptBatch is a list of prompts to embed.
-	PromptBatch []string `json:"prompt_batch,omitempty"`
+	Prompt string `json:"prompt"`
 
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// this request.
@@ -246,8 +243,7 @@ type EmbedResponse struct {
 
 // EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding,omitempty"`
-	EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
+	Embedding []float64 `json:"embedding"`
 }
 
 // CreateRequest is the request passed to [Client.Create].

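The first file drops the batch fields from the old embeddings API types: EmbeddingRequest keeps a single Prompt (now serialized without omitempty) and EmbeddingResponse keeps a single Embedding. A minimal sketch of the resulting wire format, assuming only the fields visible in the hunks above (the model name and prompt are placeholders):

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors api.EmbeddingRequest after this commit: one prompt per request,
// always emitted, and no prompt_batch field.
type EmbeddingRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// Mirrors api.EmbeddingResponse after this commit: one embedding per response.
type EmbeddingResponse struct {
	Embedding []float64 `json:"embedding"`
}

func main() {
	b, _ := json.Marshal(EmbeddingRequest{Model: "all-minilm", Prompt: "why is the sky blue?"})
	fmt.Println(string(b)) // {"model":"all-minilm","prompt":"why is the sky blue?"}
}
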
View File

@@ -3156,14 +3156,14 @@ int main(int argc, char **argv) {
 {
     res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
     const json body = json::parse(req.body);
-    json input;
+    json prompt;
     if (body.count("content") != 0)
     {
-        input = body["content"];
+        prompt = body["content"];
     }
     else
     {
-        input = "";
+        prompt = "";
     }

     // create and queue the task
@@ -3171,7 +3171,7 @@ int main(int argc, char **argv) {
 {
     const int id_task = llama.queue_tasks.get_new_id();
     llama.queue_results.add_waiting_task_id(id_task);
-    llama.request_completion(id_task, {{"prompt", input}}, true, -1);
+    llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);

     // get the result
     task_result result = llama.queue_results.recv(id_task);

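In the second file, the embedded server handler renames the local `input` to `prompt`: the request body's "content" field (when present) becomes the "prompt" of the queued completion task, otherwise an empty prompt is used. The same fallback rendered in Go, purely to illustrate the logic shown above (the function name and map type are hypothetical):

package main

import "fmt"

// extractPrompt mirrors the C++ handler above: use body["content"] when the
// key exists, otherwise fall back to an empty prompt.
func extractPrompt(body map[string]any) any {
	if v, ok := body["content"]; ok {
		return v
	}
	return ""
}

func main() {
	fmt.Println(extractPrompt(map[string]any{"content": "hello"})) // hello
	fmt.Println(extractPrompt(map[string]any{}))                   // (empty string)
}
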
View File

@@ -473,6 +473,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
 		return
 	}
+
 	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
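
The route handler in the last file reads req.Prompt, asks the loaded runner for a single embedding, and returns it as api.EmbeddingResponse. A minimal sketch of calling that route over HTTP; the endpoint path and default port follow Ollama's public embeddings API, the model name is a placeholder, and error handling is reduced to panics for brevity:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Build the single-prompt request body matching the struct after this commit.
	body, _ := json.Marshal(map[string]string{
		"model":  "all-minilm", // placeholder model name
		"prompt": "why is the sky blue?",
	})

	resp, err := http.Post("http://localhost:11434/api/embeddings", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode the single-embedding response; no embedding_batch field remains.
	var out struct {
		Embedding []float64 `json:"embedding"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(len(out.Embedding)) // dimensionality of the returned vector
}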