Merge pull request #462 from jmorganca/mxyng/rm-marshal-prompt

remove marshalPrompt which is no longer needed
Michael Yang 2023-09-05 11:48:41 -07:00 committed by GitHub
commit 7b5aefb427
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 82 deletions
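Note on the change: the new Predict no longer calls marshalPrompt to trim tokens up front. Instead it decodes the previous context tokens back into text, appends the incoming prompt with a strings.Builder, streams each prediction's content to the caller, and re-encodes the accumulated conversation once the final event arrives, returning that as the context for the next request. The sketch below walks through that flow in isolation; the wordTokenizer, the predict helper, and the canned output pieces are invented stand-ins for the llama.cpp-backed Encode/Decode and the real streaming loop, not code from this commit.

package main

import (
    "fmt"
    "strings"
)

// wordTokenizer is a stand-in for the llama.cpp-backed Encode/Decode on the
// real *llama type; splitting on whitespace is purely illustrative.
type wordTokenizer struct {
    vocab []string
}

func (t *wordTokenizer) Encode(text string) []int {
    var ids []int
    for _, w := range strings.Fields(text) {
        t.vocab = append(t.vocab, w)
        ids = append(ids, len(t.vocab)-1)
    }
    return ids
}

func (t *wordTokenizer) Decode(ids []int) string {
    words := make([]string, 0, len(ids))
    for _, id := range ids {
        words = append(words, t.vocab[id])
    }
    return strings.Join(words, " ")
}

// predict mirrors the shape of the new Predict: decode the previous context,
// append the prompt, stream each generated piece, and re-encode the whole
// conversation as the context to hand back for the next call.
func predict(t *wordTokenizer, prevContext []int, prompt string, pieces []string, fn func(string)) []int {
    var nextContext strings.Builder
    nextContext.WriteString(t.Decode(prevContext))
    nextContext.WriteString(prompt)

    for _, piece := range pieces {
        fn(piece)                      // stream to the caller, like fn(api.GenerateResponse{Response: p.Content})
        nextContext.WriteString(piece) // fold into the running conversation
    }

    return t.Encode(nextContext.String())
}

func main() {
    tok := &wordTokenizer{}
    prev := tok.Encode("You are a helpful assistant.")
    ctx := predict(tok, prev, " Hello!", []string{" Hi", " there."}, func(s string) {
        fmt.Print(s)
    })
    fmt.Printf("\nnext context has %d tokens\n", len(ctx))
}

Keeping the conversation as plain text and encoding it once at the end is what lets the old client-side trimming in marshalPrompt go away, presumably leaving context-window handling to the llama.cpp server via the NKeep option that is still sent in PredictRequest.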

View File

@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
         runner.Path,
         append(params, "--port", strconv.Itoa(port))...,
     )
-    var stderr bytes.Buffer
-    cmd.Stderr = &stderr
+    cmd.Stdout = os.Stderr
+    cmd.Stderr = os.Stderr

     llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
     llm.Options = opts
 }

-type Prediction struct {
-    Content string `json:"content"`
-    Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
     FrequencyPenalty float64 `json:"frequency_penalty"`
     IgnoreEOS        bool    `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }

 type Timings struct {
-    PredictedMS         float64 `json:"predicted_ms"`
-    PredictedN          int     `json:"predicted_n"`
-    PredictedPerSecond  float64 `json:"predicted_per_second"`
-    PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-    PromptMS            float64 `json:"prompt_ms"`
-    PromptN             int     `json:"prompt_n"`
-    PromptPerSecond     float64 `json:"prompt_per_second"`
-    PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+    PredictedN  int     `json:"predicted_n"`
+    PredictedMS float64 `json:"predicted_ms"`
+    PromptN     int     `json:"prompt_n"`
+    PromptMS    float64 `json:"prompt_ms"`
 }

-type PredictComplete struct {
-    Content            string             `json:"content"`
-    GenerationSettings GenerationSettings `json:"generation_settings"`
-    Model              string             `json:"model"`
-    Prompt             string             `json:"prompt"`
-    Stop               bool               `json:"stop"`
-    StoppedEOS         bool               `json:"stopped_eos"`
-    StoppedLimit       bool               `json:"stopped_limit"`
-    StoppedWord        bool               `json:"stopped_word"`
-    StoppingWord       string             `json:"stopping_word"`
-    Timings            Timings            `json:"timings"`
-    TokensCached       int                `json:"tokens_cached"`
-    TokensEvaluated    int                `json:"tokens_evaluated"`
-    TokensPredicted    int                `json:"tokens_predicted"`
-    Truncated          bool               `json:"truncated"`
-}
+type Prediction struct {
+    Content string `json:"content"`
+    Model   string `json:"model"`
+    Prompt  string `json:"prompt"`
+    Stop    bool   `json:"stop"`
+
+    Timings `json:"timings"`
+}

 type PredictRequest struct {
@@ -437,15 +420,19 @@ type PredictRequest struct {
     Stop     []string `json:"stop,omitempty"`
 }

-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-    // we need to find the trimmed prompt context before predicting so that we can return it to the client
-    trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+    prevConvo, err := llm.Decode(ctx, prevContext)
     if err != nil {
-        return fmt.Errorf("marshaling prompt: %v", err)
+        return err
     }

+    var nextContext strings.Builder
+    nextContext.WriteString(prevConvo)
+    nextContext.WriteString(prompt)
+
     endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
     predReq := PredictRequest{
-        Prompt:   trimmedPrompt,
+        Prompt:   nextContext.String(),
         Stream:   true,
         NPredict: llm.NumPredict,
         NKeep:    llm.NumKeep,
@@ -491,7 +478,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
     }

     scanner := bufio.NewScanner(resp.Body)
-    genCtx := trimmedPrompt // start with the trimmed prompt
     for scanner.Scan() {
         select {
         case <-ctx.Done():
@@ -506,34 +492,31 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
             // Read data from the server-side event stream
             if strings.HasPrefix(line, "data: ") {
                 evt := line[6:]
-                var complete PredictComplete
-                if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-                    return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+                var p Prediction
+                if err := json.Unmarshal([]byte(evt), &p); err != nil {
+                    return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
                 }

-                if complete.Timings.PredictedMS > 0 {
-                    genCtx += complete.Content
-                    embd, err := llm.Encode(ctx, genCtx)
+                fn(api.GenerateResponse{Response: p.Content})
+                nextContext.WriteString(p.Content)
+
+                if p.Stop {
+                    embd, err := llm.Encode(ctx, nextContext.String())
                     if err != nil {
                         return fmt.Errorf("encoding context: %v", err)
                     }

                     fn(api.GenerateResponse{
                         Done:               true,
                         Context:            embd,
-                        PromptEvalCount:    int(complete.Timings.PromptN),
-                        PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-                        EvalCount:          int(complete.Timings.PredictedN),
-                        EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+                        PromptEvalCount:    p.PromptN,
+                        PromptEvalDuration: parseDurationMs(p.PromptMS),
+                        EvalCount:          p.PredictedN,
+                        EvalDuration:       parseDurationMs(p.PredictedMS),
                     })
                     return nil
                 }
-
-                var pred Prediction
-                if err := json.Unmarshal([]byte(evt), &pred); err != nil {
-                    return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-                }
-                genCtx += pred.Content
-                fn(api.GenerateResponse{Response: pred.Content})
             }
         }
     }
@@ -545,34 +528,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
     return nil
 }

-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-    pEncode, err := llm.Encode(ctx, prompt)
-    if err != nil {
-        return "", fmt.Errorf("encoding prompt context: %w", err)
-    }
-
-    tokens := append(pCtx, pEncode...)
-    if llm.NumKeep < 0 {
-        llm.NumKeep = len(tokens)
-    }
-
-    // min(llm.NumCtx - 4, llm.NumKeep)
-    if llm.NumCtx-4 < llm.NumKeep {
-        llm.NumKeep = llm.NumCtx - 4
-    }
-
-    if len(tokens) >= llm.NumCtx {
-        // truncate input
-        numLeft := (llm.NumCtx - llm.NumKeep) / 2
-        truncated := tokens[:llm.NumKeep]
-        erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-        truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-        tokens = truncated
-        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-    }
-
-    return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
     Content string `json:"content"`
 }
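The streaming loop now unmarshals every server-sent event into the single Prediction type and only re-encodes the accumulated context when an event arrives with stop set, rather than probing Timings.PredictedMS on a separate PredictComplete type. A self-contained sketch of that parse step follows; the hard-coded sample stream and its payload values are invented for illustration and are not output captured from the llama.cpp server.

package main

import (
    "bufio"
    "encoding/json"
    "fmt"
    "strings"
)

// Prediction matches the consolidated struct from the diff: intermediate
// events carry only content, the final event carries stop plus timings.
type Prediction struct {
    Content string `json:"content"`
    Model   string `json:"model"`
    Prompt  string `json:"prompt"`
    Stop    bool   `json:"stop"`

    Timings `json:"timings"`
}

type Timings struct {
    PredictedN  int     `json:"predicted_n"`
    PredictedMS float64 `json:"predicted_ms"`
    PromptN     int     `json:"prompt_n"`
    PromptMS    float64 `json:"prompt_ms"`
}

func main() {
    // Stand-in for resp.Body; the event payloads are invented for the example.
    stream := strings.NewReader(
        "data: {\"content\":\"Hello\"}\n" +
            "data: {\"content\":\", world\"}\n" +
            "data: {\"content\":\"\",\"stop\":true,\"timings\":{\"predicted_n\":2,\"predicted_ms\":12.5,\"prompt_n\":5,\"prompt_ms\":3.1}}\n")

    var nextContext strings.Builder
    scanner := bufio.NewScanner(stream)
    for scanner.Scan() {
        line := scanner.Text()
        if !strings.HasPrefix(line, "data: ") {
            continue
        }

        var p Prediction
        if err := json.Unmarshal([]byte(line[6:]), &p); err != nil {
            fmt.Println("unmarshal error:", err)
            return
        }

        nextContext.WriteString(p.Content) // accumulate streamed content
        if p.Stop {
            fmt.Printf("done: %q (prompt_n=%d predicted_n=%d)\n",
                nextContext.String(), p.PromptN, p.PredictedN)
            return
        }
    }
}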

View File

@@ -117,12 +117,13 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses
         if err != nil {
             return err
         }

         tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
         if err != nil {
             return err
         }

-        opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+        opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)

         llmModel.SetOptions(opts)
     }
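With the Predict changes above, NumKeep here reduces to the number of tokens contributed by the system prompt: the difference between tokenizing the prompt with and without its system section, no longer padded by one. A toy illustration of that arithmetic, with invented token slices:

package main

import "fmt"

func main() {
    // Hypothetical token IDs for the same prompt rendered with and
    // without the system section; the values are made up.
    tokensWithSystem := []int{1, 512, 88, 907, 33, 2}
    tokensNoSystem := []int{1, 33, 2}

    // NumKeep now counts only the system prompt's tokens.
    numKeep := len(tokensWithSystem) - len(tokensNoSystem)
    fmt.Println("NumKeep =", numKeep) // NumKeep = 3
}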