set context length

This commit is contained in:
Roy Han 2024-07-10 15:21:46 -07:00
parent cdb9fe9b06
commit 694388db90

View File

@@ -320,6 +320,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
kvData, err := getKVData(model.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkFit := func(s string, truncate bool) (string, error) { checkFit := func(s string, truncate bool) (string, error) {
tokens, err := r.Tokenize(c.Request.Context(), s) tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil { if err != nil {
@@ -327,9 +333,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return "", err return "", err
} }
if len(tokens) > opts.NumCtx { ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if truncate { if truncate {
tokens = tokens[:opts.NumCtx] tokens = tokens[:ctxLen]
return r.Detokenize(c.Request.Context(), tokens) return r.Detokenize(c.Request.Context(), tokens)
} else { } else {
return "", fmt.Errorf("input length exceeds maximum context length") return "", fmt.Errorf("input length exceeds maximum context length")