playing around with truncate stuff
commit 80c1a3f812
parent c111d8bb51
@@ -216,6 +216,8 @@ type EmbedRequest struct {
 	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	Truncate *bool `json:"truncate,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
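Note: Truncate is a *bool rather than a plain bool so the server can tell "field omitted" apart from an explicit false, and with omitempty an unset pointer is dropped from the JSON entirely. A minimal standalone sketch of that tri-state behavior (embedRequest is an illustrative stand-in, not the real type):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Illustrative stand-in for the EmbedRequest field above.
    type embedRequest struct {
        Truncate *bool `json:"truncate,omitempty"`
    }

    func main() {
        f := false
        for _, r := range []embedRequest{{}, {Truncate: &f}} {
            b, _ := json.Marshal(r)
            fmt.Println(string(b))
        }
        // Output:
        // {}
        // {"truncate":false}
    }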
llm/ext_server/server.cpp (vendored, 8 changes)
@@ -1206,6 +1206,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
+                {"truncated", slot.truncated}
             };
         }
         else
@@ -1223,6 +1224,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"embedding", std::vector<float>(n_embd, 0.0f)},
+                {"truncated", slot.truncated}
             };
             continue;
         }
@@ -1231,6 +1233,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"embedding", std::vector<float>(embd, embd + n_embd)},
+                {"truncated", slot.truncated}
             };
         }
     }
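Note: after the three hunks above, every embedding slot result carries a truncated flag next to the vector. A sketch of what a consumer could decode, with the struct shape assumed from the JSON keys in this diff:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Assumed shape of one slot result after this change.
    type slotResult struct {
        Embedding []float32 `json:"embedding"`
        Truncated bool      `json:"truncated"`
    }

    func main() {
        payload := []byte(`{"embedding":[0,0],"truncated":true}`)
        var r slotResult
        if err := json.Unmarshal(payload, &r); err != nil {
            panic(err)
        }
        fmt.Println(len(r.Embedding), r.Truncated) // 2 true
    }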
@@ -3060,6 +3063,7 @@ int main(int argc, char **argv) {
         if (!json_value(data, "stream", false)) {
             std::string completion_text;
             task_result result = llama.queue_results.recv(task_id);
+            LOG_INFO("completion", {{"result", result.result_json}});
             if (!result.error && result.stop) {
                 res.set_content(result.result_json.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
             }
@@ -3075,6 +3079,7 @@ int main(int argc, char **argv) {
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
+            LOG_INFO("completion", {{"result", result.result_json}});
             if (!result.error) {
                 const std::string str =
                     "data: " +
@@ -3180,6 +3185,7 @@ int main(int argc, char **argv) {
         if (result.result_json.count("results")) {
             // result for multi-task
             responses = result.result_json.at("results");
+            LOG_INFO("results", {result.result_json});
         } else {
             // result for single task
             responses = std::vector<json>(1, result.result_json);
@@ -3198,6 +3204,8 @@ int main(int argc, char **argv) {
         }
         // send the result
         json result = json{{"embedding", embeddings}};
+        // log result
+
         return res.set_content(result.dump(), "application/json; charset=utf-8");
     } else {
         // return error
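Note: the LOG_INFO calls and the bare "// log result" comment in the four hunks above read as temporary debug instrumentation, in line with the work-in-progress commit message; they dump the full result JSON for completion and embedding responses.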
llm/ext_server/utils.hpp (vendored, 2 changes)
@@ -658,7 +658,7 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
 }
 
 // normalize a vector
-std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
+static std::vector<float> normalize_vector(const std::vector<float>& vec, int size) {
     double sum = 0.0;
     for (float value : vec) {
         sum += value * value;
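Note: adding static gives normalize_vector internal linkage. Since the function is defined (not just declared) in this header, every translation unit that includes it would otherwise emit the same external symbol and fail at link time with a duplicate-definition error.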
@@ -356,6 +356,11 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
+	if req.Truncate == nil {
+		truncate := true
+		req.Truncate = &truncate
+	}
+
 	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
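Note: the nil check above defaults truncation to on when the client omitted the field. The same idiom in isolation (a sketch; names illustrative):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type embedRequest struct {
        Truncate *bool `json:"truncate,omitempty"`
    }

    func main() {
        var req embedRequest
        _ = json.Unmarshal([]byte(`{}`), &req) // client sent no "truncate"

        // Mirror of the handler logic: default to true when unset.
        if req.Truncate == nil {
            truncate := true
            req.Truncate = &truncate
        }
        fmt.Println(*req.Truncate) // true
    }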
|
Loading…
x
Reference in New Issue
Block a user