From ca7c3f7e0f9b14ea435aedef1c448a37a44f423c Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 12 Mar 2024 21:17:25 -0700 Subject: [PATCH] limit `num_predict` to `num_ctx` --- llm/dyn_ext_server.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go index 832d3c47b..e690e9fa5 100644 --- a/llm/dyn_ext_server.go +++ b/llm/dyn_ext_server.go @@ -172,6 +172,19 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images))) } + // Limit the number of predictions to the maximum context length + // this will cause no more than two context shifts + // TODO: limit this further to num_ctx - len(prompt) to avoid + // any context shifts at all + if predict.Options.NumPredict > llm.options.NumCtx { + slog.Warn(fmt.Sprintf("requested num_predict is greater than the context length (%d > %d), using %d instead", predict.Options.NumPredict, llm.options.NumCtx, llm.options.NumCtx)) + predict.Options.NumPredict = llm.options.NumCtx + } + + if predict.Options.NumPredict == -1 { + predict.Options.NumPredict = llm.options.NumCtx + } + request := map[string]any{ "prompt": predict.Prompt, "stream": true,