playing around with truncate stuff

This commit is contained in:
Roy Han 2024-06-28 18:17:09 -07:00
parent c111d8bb51
commit 80c1a3f812
4 changed files with 16 additions and 1 deletions

View File

@@ -216,6 +216,8 @@ type EmbedRequest struct {
 	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
+
+	Truncate *bool `json:"truncate,omitempty"`
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }

View File

@@ -1206,6 +1206,7 @@ struct llama_server_context
 res.result_json = json
 {
 {"embedding", std::vector<float>(n_embd, 0.0f)},
+{"truncated", slot.truncated}
 };
 }
 else
else else
@@ -1223,6 +1224,7 @@ struct llama_server_context
 res.result_json = json
 {
 {"embedding", std::vector<float>(n_embd, 0.0f)},
+{"truncated", slot.truncated}
 };
 continue;
 }
@@ -1231,6 +1233,7 @@ struct llama_server_context
 res.result_json = json
 {
 {"embedding", std::vector<float>(embd, embd + n_embd)},
+{"truncated", slot.truncated}
 };
 }
 }
@@ -3060,6 +3063,7 @@ int main(int argc, char **argv) {
 if (!json_value(data, "stream", false)) {
 std::string completion_text;
 task_result result = llama.queue_results.recv(task_id);
+LOG_INFO("completion", {{"result", result.result_json}});
 if (!result.error && result.stop) {
 res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
 }
@@ -3075,6 +3079,7 @@ int main(int argc, char **argv) {
 while (true)
 {
 task_result result = llama.queue_results.recv(task_id);
+LOG_INFO("completion", {{"result", result.result_json}});
 if (!result.error) {
 const std::string str =
 "data: " +
@@ -3180,6 +3185,7 @@ int main(int argc, char **argv) {
 if (result.result_json.count("results")) {
 // result for multi-task
 responses = result.result_json.at("results");
+LOG_INFO("results", {result.result_json});
 } else {
 // result for single task
 responses = std::vector<json>(1, result.result_json);
@@ -3198,6 +3204,8 @@ int main(int argc, char **argv) {
 }
 // send the result
 json result = json{{"embedding", embeddings}};
+// log result
+
 return res.set_content(result.dump(), "application/json; charset=utf-8");
 } else {
 // return error

View File

@@ -658,7 +658,7 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
 }
 // normalize a vector
-std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
 double sum = 0.0;
 for (float value : vec) {
 sum += value * value;

View File

@@ -356,6 +356,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 return
 }
+
+if req.Truncate == nil {
+truncate := true
+req.Truncate = &truncate
+}
 model, err := GetModel(req.Model)
 if err != nil {
 var pErr *fs.PathError