diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index b21d219c2..31b81e62f 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -11,9 +11,10 @@ import ( "slices" "strings" - "github.com/d4l3k/go-bfloat16" "github.com/x448/float16" "golang.org/x/exp/maps" + + "github.com/ollama/ollama/types/bfloat16" ) type safetensorMetadata struct { diff --git a/go.mod b/go.mod index cc5789005..9ebf4f1c3 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( require ( github.com/agnivade/levenshtein v1.1.1 - github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/dlclark/regexp2 v1.11.4 github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/google/go-cmp v0.6.0 diff --git a/go.sum b/go.sum index 0ab97b909..419c8853a 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY= -github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/types/bfloat16/LICENSE b/types/bfloat16/LICENSE new file mode 100644 index 000000000..ec7e4939e --- /dev/null +++ b/types/bfloat16/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Tristan Rice + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/types/bfloat16/bfloat16.go b/types/bfloat16/bfloat16.go new file mode 100644 index 000000000..070302d73 --- /dev/null +++ b/types/bfloat16/bfloat16.go @@ -0,0 +1,57 @@ +// Vendored code from https://github.com/d4l3k/go-bfloat16 +// unsafe pointer replaced by "math" +package bfloat16 + +import "math" + +type BF16 uint16 + +func FromBytes(buf []byte) BF16 { + return BF16(uint16(buf[0]) + uint16(buf[1])<<8) +} + +func ToBytes(b BF16) []byte { + return []byte{byte(b & 0xFF), byte(b >> 8)} +} + +func Decode(buf []byte) []BF16 { + var out []BF16 + for i := 0; i < len(buf); i += 2 { + out = append(out, FromBytes(buf[i:])) + } + return out +} + +func Encode(f []BF16) []byte { + var out []byte + for _, a := range f { + out = append(out, ToBytes(a)...) + } + return out +} + +func DecodeFloat32(buf []byte) []float32 { + var out []float32 + for i := 0; i < len(buf); i += 2 { + out = append(out, ToFloat32(FromBytes(buf[i:]))) + } + return out +} + +func EncodeFloat32(f []float32) []byte { + var out []byte + for _, a := range f { + out = append(out, ToBytes(FromFloat32(a))...) + } + return out +} + +func ToFloat32(b BF16) float32 { + u32 := uint32(b) << 16 + return math.Float32frombits(u32) +} + +func FromFloat32(f float32) BF16 { + u32 := math.Float32bits(f) + return BF16(u32 >> 16) +} diff --git a/types/bfloat16/bfloat16_test.go b/types/bfloat16/bfloat16_test.go new file mode 100644 index 000000000..aa7257393 --- /dev/null +++ b/types/bfloat16/bfloat16_test.go @@ -0,0 +1,53 @@ +package bfloat16 + +import ( + "crypto/rand" + "reflect" + "testing" +) + +func randomBytes(n int) []byte { + out := make([]byte, n) + if _, err := rand.Read(out); err != nil { + panic(err) + } + return out +} + +func TestEncodeDecode(t *testing.T) { + b := randomBytes(1024) + bf16 := Decode(b) + out := Encode(bf16) + if !reflect.DeepEqual(b, out) { + t.Fatalf("%+v != %+v", b, out) + } +} + +func TestEncodeDecodeFloat32(t *testing.T) { + b := randomBytes(1024) + bf16 := DecodeFloat32(b) + out := EncodeFloat32(bf16) + if !reflect.DeepEqual(b, out) { + t.Fatalf("%+v != %+v", b, out) + } +} + +func TestBasicFloat32(t *testing.T) { + var in float32 = 1.0 + out := ToFloat32(FromFloat32(in)) + if !reflect.DeepEqual(in, out) { + t.Fatalf("%+v != %+v", in, out) + } +} + +func TestComplexFloat32(t *testing.T) { + var in float32 = 123456789123456789.123456789 + var want float32 = 123286039799267328.0 + out := ToFloat32(FromFloat32(in)) + if in == out { + t.Fatalf("no loss of precision") + } + if out != want { + t.Fatalf("%.16f != %.16f", want, out) + } +}