From c71698426c39caeffe7b622fccbb6c39595aab70 Mon Sep 17 00:00:00 2001 From: Roy Han Date: Mon, 24 Jun 2024 11:09:08 -0700 Subject: [PATCH] Separate Rounding Functions --- cmd/cmd.go | 4 ++-- format/format.go | 36 +++++++++++++++++++++++++++++------- format/format_test.go | 32 ++++++++++++++++++++++++++++++-- server/images.go | 2 +- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 3ac4a3e15..577a461a0 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -656,7 +656,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error { modelData := [][]string{ {"arch", arch}, - {"parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))}, + {"parameters", format.Parameters(uint64(resp.ModelInfo["general.parameter_count"].(float64)))}, {"quantization", resp.Details.QuantizationLevel}, {"context length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))}, {"embedding length", fmt.Sprintf("%v", resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64))}, @@ -670,7 +670,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error { if resp.ProjectorInfo != nil { projectorData := [][]string{ {"arch", "clip"}, - {"parameters", format.HumanNumber(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))}, + {"parameters", format.Parameters(uint64(resp.ProjectorInfo["general.parameter_count"].(float64)))}, {"projector type", resp.ProjectorInfo["clip.projector_type"].(string)}, {"embedding length", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.embedding_length"].(float64))}, {"projection dimensionality", fmt.Sprintf("%v", resp.ProjectorInfo["clip.vision.projection_dim"].(float64))}, diff --git a/format/format.go b/format/format.go index b48962c97..65bb17ed7 100644 --- a/format/format.go +++ b/format/format.go @@ -2,16 +2,38 @@ package format import ( "fmt" + "math" ) -func HumanNumber(b uint64) string { - const ( - Thousand = 1000 - Million = Thousand * 1000 - Billion = Million * 1000 - Trillion = Billion * 1000 - ) +const ( + Thousand = 1000 + Million = Thousand * 1000 + Billion = Million * 1000 + Trillion = Billion * 1000 +) +func RoundedParameter(b uint64) string { + switch { + case b >= Billion: + number := float64(b) / Billion + if number == math.Floor(number) { + return fmt.Sprintf("%.0fB", number) // no decimals if whole number + } + return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number + case b >= Million: + number := float64(b) / Million + if number == math.Floor(number) { + return fmt.Sprintf("%.0fM", number) // no decimals if whole number + } + return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number + case b >= Thousand: + return fmt.Sprintf("%.0fK", float64(b)/Thousand) + default: + return fmt.Sprintf("%d", b) + } +} + +func Parameters(b uint64) string { switch { case b >= Trillion: number := float64(b) / Trillion diff --git a/format/format_test.go b/format/format_test.go index ec4174f57..1f7b29cbc 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -4,7 +4,35 @@ import ( "testing" ) -func TestHumanNumber(t *testing.T) { +func TestRoundedParameter(t *testing.T) { + type testCase struct { + input uint64 + expected string + } + + testCases := []testCase{ + {0, "0"}, + {1000000, "1M"}, + {125000000, "125M"}, + {500500000, "500.50M"}, + {500550000, "500.55M"}, + {1000000000, "1B"}, + {2800000000, "2.8B"}, + {2850000000, "2.9B"}, + {1000000000000, "1000B"}, + } + + for _, tc := range testCases { + t.Run(tc.expected, func(t *testing.T) { + result := RoundedParameter(tc.input) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} + +func TestParameters(t *testing.T) { type testCase struct { input uint64 expected string @@ -23,7 +51,7 @@ func TestHumanNumber(t *testing.T) { for _, tc := range testCases { t.Run(tc.expected, func(t *testing.T) { - result := HumanNumber(tc.input) + result := Parameters(tc.input) if result != tc.expected { t.Errorf("Expected %s, got %s", tc.expected, result) } diff --git a/server/images.go b/server/images.go index 53a957715..67c0898da 100644 --- a/server/images.go +++ b/server/images.go @@ -431,7 +431,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio if baseLayer.GGML != nil { config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name()) config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture()) - config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount())) + config.ModelType = cmp.Or(config.ModelType, format.RoundedParameter(baseLayer.GGML.KV().ParameterCount())) config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String()) config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) }