diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go index 832d3c47b..e690e9fa5 100644 --- a/llm/dyn_ext_server.go +++ b/llm/dyn_ext_server.go @@ -172,6 +172,19 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images))) } + // Limit the number of predictions to the maximum context length + // this will cause no more than two context shifts + // TODO: limit this further to num_ctx - len(prompt) to avoid + // any context shifts at all + if predict.Options.NumPredict > llm.options.NumCtx { + slog.Warn(fmt.Sprintf("requested num_predict is greater than the context length (%d > %d), using %d instead", predict.Options.NumPredict, llm.options.NumCtx, llm.options.NumCtx)) + predict.Options.NumPredict = llm.options.NumCtx + } + + if predict.Options.NumPredict == -1 { + predict.Options.NumPredict = llm.options.NumCtx + } + request := map[string]any{ "prompt": predict.Prompt, "stream": true,