embed text document in modelfile
This commit is contained in:
@@ -17,15 +17,18 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llama"
|
||||
"github.com/jmorganca/ollama/vector"
|
||||
)
|
||||
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
llm *llama.LLM
|
||||
llm *llama.LLM
|
||||
Embeddings []vector.Embedding
|
||||
|
||||
expireAt time.Time
|
||||
expireTimer *time.Timer
|
||||
@@ -72,6 +75,11 @@ func GenerateHandler(c *gin.Context) {
|
||||
loaded.digest = ""
|
||||
}
|
||||
|
||||
if model.Embeddings != nil && len(model.Embeddings) > 0 {
|
||||
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
|
||||
loaded.Embeddings = model.Embeddings
|
||||
}
|
||||
|
||||
llm, err := llama.New(model.ModelPath, opts)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -101,7 +109,6 @@ func GenerateHandler(c *gin.Context) {
|
||||
loaded.digest = model.Digest
|
||||
loaded.options = opts
|
||||
}
|
||||
|
||||
sessionDuration := 5 * time.Minute
|
||||
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
@@ -127,7 +134,22 @@ func GenerateHandler(c *gin.Context) {
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt, err := model.Prompt(req)
|
||||
embedding := ""
|
||||
if model.Embeddings != nil && len(model.Embeddings) > 0 {
|
||||
promptEmbed, err := loaded.llm.Embedding(req.Prompt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// TODO: set embed_top from specified parameters in modelfile
|
||||
embed_top := 3
|
||||
topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
|
||||
for _, e := range topK {
|
||||
embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
|
||||
}
|
||||
}
|
||||
|
||||
prompt, err := model.Prompt(req, embedding)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user