From 6ed88985903be474ecd59992f7191c2f0fa87e36 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Fri, 25 Apr 2025 16:16:15 -0700 Subject: [PATCH] ggml: fix crash for array head counts If it's an array, it uses the max value in the array If array values for head counts becomes more popular, we can consider a more invasive change like #10225 to calculate more accurate estimates. Fixes: #9984 --- fs/ggml/ggml.go | 103 +++++++++++++++++++++++++++++++------------ fs/ggml/ggml_test.go | 30 +++++++++++++ llm/memory.go | 7 ++- 3 files changed, 110 insertions(+), 30 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 427a43aec..0d38f29e8 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -33,7 +33,8 @@ func (kv KV) Kind() string { } func (kv KV) ParameterCount() uint64 { - return keyValue(kv, "general.parameter_count", uint64(0)) + val, _ := keyValue(kv, "general.parameter_count", uint64(0)) + return val } func (kv KV) FileType() fileType { @@ -52,16 +53,27 @@ func (kv KV) EmbeddingLength() uint64 { return uint64(kv.Uint("embedding_length")) } -func (kv KV) HeadCount() uint64 { - return uint64(kv.Uint("attention.head_count")) +func (kv KV) HeadCountMax() uint64 { + // TODO(drifkin): using the max value can cause an overestimation. 
In the + // future if array values become more popular, we can adapt the more invasive + // change from #10225 to calculate more accurate estimates. + return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1)) } -func (kv KV) HeadCountKV() uint64 { - return uint64(kv.Uint("attention.head_count_kv", 1)) +func (kv KV) HeadCountMin() uint64 { + return uint64(kv.UintOrMinArrayValue("attention.head_count", 1)) } -func (kv KV) EmbeddingHeadCount() uint64 { - if heads := kv.HeadCount(); heads > 0 { +func (kv KV) HeadCountKVMax() uint64 { + return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1)) +} + +func (kv KV) HeadCountKVMin() uint64 { + return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1)) +} + +func (kv KV) EmbeddingHeadCountMax() uint64 { + if heads := kv.HeadCountMin(); heads > 0 { return kv.EmbeddingLength() / heads } @@ -69,15 +81,11 @@ func (kv KV) EmbeddingHeadCount() uint64 { } func (kv KV) EmbeddingHeadCountK() uint64 { - return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount()))) + return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax()))) } func (kv KV) EmbeddingHeadCountV() uint64 { - return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount()))) -} - -func (kv KV) GQA() uint64 { - return kv.HeadCount() / kv.HeadCountKV() + return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax()))) } func (kv KV) ContextLength() uint64 { @@ -89,35 +97,72 @@ func (kv KV) ChatTemplate() string { } func (kv KV) String(key string, defaultValue ...string) string { - return keyValue(kv, key, append(defaultValue, "")...) + val, _ := keyValue(kv, key, append(defaultValue, "")...) + return val } func (kv KV) Uint(key string, defaultValue ...uint32) uint32 { - return keyValue(kv, key, append(defaultValue, 0)...) + val, _ := keyValue(kv, key, append(defaultValue, 0)...) + return val } func (kv KV) Float(key string, defaultValue ...float32) float32 { - return keyValue(kv, key, append(defaultValue, 0)...)
+ val, _ := keyValue(kv, key, append(defaultValue, 0)...) + return val } func (kv KV) Bool(key string, defaultValue ...bool) bool { - return keyValue(kv, key, append(defaultValue, false)...) + val, _ := keyValue(kv, key, append(defaultValue, false)...) + return val +} + +func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 { + _, max := kv.UintOrArrayValue(key, defaultValue) + return max +} + +func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 { + min, _ := kv.UintOrArrayValue(key, defaultValue) + return min +} + +func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) { + if u32, ok := keyValue(kv, key, uint32(0)); ok { + return u32, u32 + } else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok { + min := slices.Min(u32s.values) + max := slices.Max(u32s.values) + return min, max + } else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok { + min := slices.Min(i32s.values) + max := slices.Max(i32s.values) + if min < 0 || max < 0 { + slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max) + } + return uint32(min), uint32(max) + } + + return defaultValue, defaultValue } func (kv KV) Strings(key string, defaultValue ...[]string) []string { - return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values + val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}) + return val.values } func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 { - return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values + val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}) + return val.values } func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { - return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values + val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, 
[]uint32(nil))[0]}) + return val.values } func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { - return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values + val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}) + return val.values } func (kv KV) OllamaEngineRequired() bool { @@ -140,17 +185,17 @@ type arrayValueTypes interface { *array[string] | *array[float32] | *array[float64] | *array[bool] } -func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T { +func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { key = kv.Architecture() + "." + key } - if val, ok := kv[key]; ok { - return val.(T) + if val, ok := kv[key].(T); ok { + return val, true } - slog.Warn("key not found", "key", key, "default", defaultValue[0]) - return defaultValue[0] + slog.Warn("key with type not found", "key", key, "default", defaultValue[0]) + return defaultValue[0], false } type Tensors struct { @@ -413,11 +458,11 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { embedding := f.KV().EmbeddingLength() - heads := f.KV().HeadCount() - headsKV := f.KV().HeadCountKV() + heads := f.KV().HeadCountMax() + headsKV := f.KV().HeadCountKVMax() vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) - embeddingHeads := f.KV().EmbeddingHeadCount() + embeddingHeads := f.KV().EmbeddingHeadCountMax() embeddingHeadsK := f.KV().EmbeddingHeadCountK() embeddingHeadsV := f.KV().EmbeddingHeadCountV() diff --git a/fs/ggml/ggml_test.go b/fs/ggml/ggml_test.go index c1c1b43b0..225e74841 100644 --- a/fs/ggml/ggml_test.go +++ b/fs/ggml/ggml_test.go @@ -269,3 +269,33 @@ func TestKeyValue(t *testing.T) 
{ t.Errorf("unexpected uint8s (-got +want):\n%s", diff) } } + +func TestHeadCount(t *testing.T) { + valuesArray := []int32{1, 5, 3, 4} + cases := []struct { + kv KV + want uint64 + }{ + { + kv: KV{ + "general.architecture": "abc", + "abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)}, + }, + want: uint64(5), + }, + { + kv: KV{ + "general.architecture": "abc", + "abc.attention.head_count": uint32(3), + }, + want: uint64(3), + }, + } + + for _, tt := range cases { + got := tt.kv.HeadCountMax() + if got != tt.want { + t.Errorf("unexpected max value: got=%d want=%d", got, tt.want) + } + } +} diff --git a/llm/memory.go b/llm/memory.go index e05327f79..d029e4d31 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -149,7 +149,12 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } if graphPartialOffload == 0 { - graphPartialOffload = f.KV().GQA() * kvTotal / 6 + headsKV := f.KV().HeadCountKVMin() + if headsKV == 0 { + headsKV = 1 + } + gqa := f.KV().HeadCountMax() / headsKV + graphPartialOffload = gqa * kvTotal / 6 } if graphFullOffload == 0 { graphFullOffload = graphPartialOffload