ml: Add support for quantized KV cache

Similar to the llama engine, quantizing the KV cache requires
flash attention to be enabled through the Ollama server.
This commit is contained in:
Jesse Gross 2025-02-21 20:54:14 -08:00 committed by Jesse Gross
parent f52b2615ef
commit 4100ed7bdd
3 changed files with 13 additions and 3 deletions

View File

@@ -215,7 +215,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string { return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
}) })
case DTypeF16: case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Empty(DTypeF32, t.Shape()...) f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32) f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
@@ -283,5 +283,7 @@ const (
DTypeOther DType = iota DTypeOther DType = iota
DTypeF32 DTypeF32
DTypeF16 DTypeF16
DTypeQ80
DTypeQ40
DTypeI32 DTypeI32
) )

View File

@@ -535,6 +535,10 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
cdtype = C.GGML_TYPE_F32 cdtype = C.GGML_TYPE_F32
case ml.DTypeF16: case ml.DTypeF16:
cdtype = C.GGML_TYPE_F16 cdtype = C.GGML_TYPE_F16
case ml.DTypeQ80:
cdtype = C.GGML_TYPE_Q8_0
case ml.DTypeQ40:
cdtype = C.GGML_TYPE_Q4_0
case ml.DTypeI32: case ml.DTypeI32:
cdtype = C.GGML_TYPE_I32 cdtype = C.GGML_TYPE_I32
default: default:
@@ -680,6 +684,10 @@ func (t *Tensor) DType() ml.DType {
return ml.DTypeF32 return ml.DTypeF32
case C.GGML_TYPE_F16: case C.GGML_TYPE_F16:
return ml.DTypeF16 return ml.DTypeF16
case C.GGML_TYPE_Q8_0:
return ml.DTypeQ80
case C.GGML_TYPE_Q4_0:
return ml.DTypeQ40
case C.GGML_TYPE_I32: case C.GGML_TYPE_I32:
return ml.DTypeI32 return ml.DTypeI32
default: default:

View File

@@ -58,9 +58,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType { func kvCacheTypeFromStr(s string) ml.DType {
switch s { switch s {
case "q8_0": case "q8_0":
panic("kv cache quantization not yet implemented") return ml.DTypeQ80
case "q4_0": case "q4_0":
panic("kv cache quantization not yet implemented") return ml.DTypeQ40
default: default:
return ml.DTypeF16 return ml.DTypeF16
} }