fix: fixes a memory leak in bfloat16 package
This change vendors in the bfloat16 package from github.com/d4l3k/go-bfloat16/ and fixes a memory leak which was being caused by using unsafe pointers instead of the math package.
This commit is contained in:
parent
021dcf089d
commit
c75b428249
@ -11,9 +11,10 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/d4l3k/go-bfloat16"
|
|
||||||
"github.com/x448/float16"
|
"github.com/x448/float16"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/types/bfloat16"
|
||||||
)
|
)
|
||||||
|
|
||||||
type safetensorMetadata struct {
|
type safetensorMetadata struct {
|
||||||
|
1
go.mod
1
go.mod
@ -16,7 +16,6 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/agnivade/levenshtein v1.1.1
|
github.com/agnivade/levenshtein v1.1.1
|
||||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
|
||||||
github.com/dlclark/regexp2 v1.11.4
|
github.com/dlclark/regexp2 v1.11.4
|
||||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||||
github.com/google/go-cmp v0.6.0
|
github.com/google/go-cmp v0.6.0
|
||||||
|
2
go.sum
2
go.sum
@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
|
|||||||
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
|
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
|
|
||||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
|
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
21
types/bfloat16/LICENSE
Normal file
21
types/bfloat16/LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2021 Tristan Rice
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
57
types/bfloat16/bfloat16.go
Normal file
57
types/bfloat16/bfloat16.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
// Vendored code from https://github.com/d4l3k/go-bfloat16
|
||||||
|
// unsafe pointer replaced by "math"
|
||||||
|
package bfloat16
|
||||||
|
|
||||||
|
import "math"
|
||||||
|
|
||||||
|
type BF16 uint16
|
||||||
|
|
||||||
|
func FromBytes(buf []byte) BF16 {
|
||||||
|
return BF16(uint16(buf[0]) + uint16(buf[1])<<8)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToBytes(b BF16) []byte {
|
||||||
|
return []byte{byte(b & 0xFF), byte(b >> 8)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Decode(buf []byte) []BF16 {
|
||||||
|
var out []BF16
|
||||||
|
for i := 0; i < len(buf); i += 2 {
|
||||||
|
out = append(out, FromBytes(buf[i:]))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func Encode(f []BF16) []byte {
|
||||||
|
var out []byte
|
||||||
|
for _, a := range f {
|
||||||
|
out = append(out, ToBytes(a)...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeFloat32(buf []byte) []float32 {
|
||||||
|
var out []float32
|
||||||
|
for i := 0; i < len(buf); i += 2 {
|
||||||
|
out = append(out, ToFloat32(FromBytes(buf[i:])))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeFloat32(f []float32) []byte {
|
||||||
|
var out []byte
|
||||||
|
for _, a := range f {
|
||||||
|
out = append(out, ToBytes(FromFloat32(a))...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToFloat32(b BF16) float32 {
|
||||||
|
u32 := uint32(b) << 16
|
||||||
|
return math.Float32frombits(u32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromFloat32(f float32) BF16 {
|
||||||
|
u32 := math.Float32bits(f)
|
||||||
|
return BF16(u32 >> 16)
|
||||||
|
}
|
53
types/bfloat16/bfloat16_test.go
Normal file
53
types/bfloat16/bfloat16_test.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package bfloat16
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func randomBytes(n int) []byte {
|
||||||
|
out := make([]byte, n)
|
||||||
|
if _, err := rand.Read(out); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecode(t *testing.T) {
|
||||||
|
b := randomBytes(1024)
|
||||||
|
bf16 := Decode(b)
|
||||||
|
out := Encode(bf16)
|
||||||
|
if !reflect.DeepEqual(b, out) {
|
||||||
|
t.Fatalf("%+v != %+v", b, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeFloat32(t *testing.T) {
|
||||||
|
b := randomBytes(1024)
|
||||||
|
bf16 := DecodeFloat32(b)
|
||||||
|
out := EncodeFloat32(bf16)
|
||||||
|
if !reflect.DeepEqual(b, out) {
|
||||||
|
t.Fatalf("%+v != %+v", b, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicFloat32(t *testing.T) {
|
||||||
|
var in float32 = 1.0
|
||||||
|
out := ToFloat32(FromFloat32(in))
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Fatalf("%+v != %+v", in, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComplexFloat32(t *testing.T) {
|
||||||
|
var in float32 = 123456789123456789.123456789
|
||||||
|
var want float32 = 123286039799267328.0
|
||||||
|
out := ToFloat32(FromFloat32(in))
|
||||||
|
if in == out {
|
||||||
|
t.Fatalf("no loss of precision")
|
||||||
|
}
|
||||||
|
if out != want {
|
||||||
|
t.Fatalf("%.16f != %.16f", want, out)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user