diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 9431e9cc1..de4ed4d54 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -53,11 +53,11 @@ func (kv KV) EmbeddingLength() uint64 {
 }
 
 func (kv KV) HeadCount() uint64 {
-	return uint64(kv.Uint("attention.head_count"))
+	return uint64(kv.UintOrFirstArrayValue("attention.head_count"))
 }
 
 func (kv KV) HeadCountKV() uint64 {
-	return uint64(kv.Uint("attention.head_count_kv", 1))
+	return uint64(kv.UintOrFirstArrayValue("attention.head_count_kv", 1))
 }
 
 func (kv KV) EmbeddingHeadCount() uint64 {
@@ -104,6 +104,46 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
 	return keyValue(kv, key, append(defaultValue, false)...)
 }
 
+// UintOrFirstArrayValue returns the uint32 stored under key. When the key
+// holds an array instead, its first element is used, provided it is a uint32
+// or a non-negative int32. If the key is missing, the array is empty, or the
+// value has any other shape, the first defaultValue is returned (0 when no
+// default is supplied).
+func (kv KV) UintOrFirstArrayValue(key string, defaultValue ...uint32) uint32 {
+	// HeadCount calls this without a default, so never index defaultValue
+	// unconditionally.
+	var fallback uint32
+	if len(defaultValue) > 0 {
+		fallback = defaultValue[0]
+	}
+
+	v, ok := keyValueUntyped(kv, key)
+	if !ok {
+		return fallback
+	}
+
+	switch v := v.(type) {
+	case uint32:
+		return v
+	case *array:
+		if v == nil || len(v.values) == 0 {
+			return fallback
+		}
+
+		switch first := v.values[0].(type) {
+		case uint32:
+			return first
+		case int32:
+			if first >= 0 {
+				return uint32(first)
+			}
+			// TODO(drifkin): indicate unexpected data somehow?
+		}
+	}
+
+	return fallback
+}
+
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
 	r := keyValue(kv, key, &array{})
 	s := make([]string, r.size)
@@ -141,11 +181,7 @@ func (kv KV) OllamaEngineRequired() bool {
 }
 
 func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
-	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
-		key = kv.Architecture() + "." + key
-	}
-
-	if val, ok := kv[key]; ok {
+	if val, ok := keyValueUntyped(kv, key); ok {
 		return val.(T)
 	}
 
@@ -153,6 +189,21 @@ func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key s
 	return defaultValue[0]
 }
 
+// keyValueUntyped looks up key in kv without asserting the value's type.
+// Keys outside the "tokenizer." and "general." namespaces are first prefixed
+// with the model architecture. The boolean reports whether the key exists.
+func keyValueUntyped(kv KV, key string) (any, bool) {
+	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
+		key = kv.Architecture() + "." + key
+	}
+
+	if val, ok := kv[key]; ok {
+		return val, true
+	}
+
+	return nil, false
+}
+
 type Tensors struct {
 	items  []*Tensor
 	Offset uint64