WIP working with other stuff
This commit is contained in:
parent
e93db4d20e
commit
73098a2973
171
sample/decode.go
Normal file
171
sample/decode.go
Normal file
@ -0,0 +1,171 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Schema is the in-memory form of a (subset of a) JSON schema node.
type Schema struct {
	// Name identifies the property this schema describes. It is never read
	// from or written to JSON (`json:"-"`); props.UnmarshalJSON fills it in
	// with the object key of each decoded property. The root schema is
	// meant to be named "root".
	Name string `json:"-"`

	// Type is the declared type of the property.
	//
	// TODO: Union types (e.g. make this a []string).
	Type string

	// PrefixItems holds one schema per positional element of a tuple. The
	// tuple is "closed" by default, unless Items is set to true or to a
	// valid Schema.
	PrefixItems []*Schema

	// Items is the schema applied to each element of a list.
	//
	// When decoding into a zero Schema: a missing key, or a JSON value of
	// "null" or "false", leaves it nil; "true" yields an empty Schema; an
	// object is decoded as a Schema.
	Items *Schema

	// MinItems is the minimum number of elements allowed in a list.
	MinItems int

	// MaxItems is the maximum number of elements allowed in a list.
	MaxItems int

	// Properties holds the schema of each property of an object.
	Properties []*Schema

	// Format names a specific format the property should conform to.
	//
	// It is the caller's responsibility to validate the property against
	// the format.
	Format string

	// Minimum is the lower bound for numeric properties.
	Minimum float64

	// Maximum is the upper bound for numeric properties.
	Maximum float64

	// Enum lists the valid values for the property, kept as raw JSON.
	Enum []json.RawMessage
}
|
||||
|
||||
func (s *Schema) UnmarshalJSON(data []byte) error {
|
||||
type S Schema
|
||||
w := struct {
|
||||
Properties props
|
||||
Items items
|
||||
*S
|
||||
}{
|
||||
S: (*S)(s),
|
||||
}
|
||||
if err := json.Unmarshal(data, &w); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.Items.set {
|
||||
s.Items = &w.Items.Schema
|
||||
}
|
||||
s.Properties = w.Properties
|
||||
return nil
|
||||
}
|
||||
|
||||
type items struct {
|
||||
Schema
|
||||
set bool
|
||||
}
|
||||
|
||||
func (s *items) UnmarshalJSON(data []byte) error {
|
||||
switch b := data[0]; b {
|
||||
case 't':
|
||||
*s = items{set: true}
|
||||
case '{':
|
||||
type I items
|
||||
if err := json.Unmarshal(data, (*I)(s)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.set = true
|
||||
case 'n', 'f':
|
||||
default:
|
||||
return errors.New("invalid Items")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EffectiveType returns the effective type of the schema. If the Type field is
|
||||
// not empty, it is returned; otherwise:
|
||||
//
|
||||
// - If the schema has both Properties and Items, it returns an empty string.
|
||||
// - If the schema has Properties, it returns "object".
|
||||
// - If the schema has Items, it returns "array".
|
||||
// - If the schema has neither Properties nor Items, it returns "value".
|
||||
//
|
||||
// The returned string is never empty.
|
||||
func (d *Schema) EffectiveType() string {
|
||||
if d.Type == "" {
|
||||
if len(d.Properties) > 0 {
|
||||
return "object"
|
||||
}
|
||||
if len(d.PrefixItems) > 0 || d.Items != nil {
|
||||
return "array"
|
||||
}
|
||||
return "value"
|
||||
}
|
||||
return d.Type
|
||||
}
|
||||
|
||||
// props is an ordered list of properties. The order of the properties
|
||||
// is the order in which they were defined in the schema.
|
||||
type props []*Schema
|
||||
|
||||
var _ json.Unmarshaler = (*props)(nil)
|
||||
|
||||
func (v *props) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
if data[0] != '{' {
|
||||
return errors.New("expected object")
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewReader(data))
|
||||
|
||||
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
|
||||
// llama.cpp, ignore unknown fields, which could be lead to unexpected
|
||||
// behavior for clients of this package, since they may not be aware
|
||||
// that "additionalFields", "itemsPrefix", etc, are being ignored.
|
||||
//
|
||||
// For now, just do what llama.cpp does.
|
||||
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t != json.Delim('{') {
|
||||
return errors.New("expected object")
|
||||
}
|
||||
for d.More() {
|
||||
// Use the first token (map key) as the property name, then
|
||||
// decode the rest of the object fields into a Schema and
|
||||
// append.
|
||||
t, err := d.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == json.Delim('}') {
|
||||
return nil
|
||||
}
|
||||
s := &Schema{
|
||||
Name: t.(string),
|
||||
}
|
||||
if err := d.Decode(s); err != nil {
|
||||
return err
|
||||
}
|
||||
*v = append(*v, s)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -30,7 +30,10 @@ const (
|
||||
StateInList
|
||||
StateInListComma
|
||||
StateListEnd
|
||||
StateInValue
|
||||
StateInValueEnd
|
||||
StateInListEnd
|
||||
StateInListObjectEnd
|
||||
StateInNewline
|
||||
StateInNumber
|
||||
StateInNumberEnd
|
||||
@ -38,6 +41,7 @@ const (
|
||||
StateInObjectKeyEnd
|
||||
StateTerminate
|
||||
StateInObjectEnd
|
||||
StateTransitioningToTerminate
|
||||
)
|
||||
|
||||
func (s JSONState) String() string {
|
||||
@ -76,6 +80,8 @@ func (s JSONState) String() string {
|
||||
return "StateInObjSpace"
|
||||
case StateInList:
|
||||
return "StateInList"
|
||||
case StateInListObjectEnd:
|
||||
return "StateInListObjectEnd"
|
||||
case StateInListComma:
|
||||
return "StateInListComma"
|
||||
case StateListEnd:
|
||||
|
@ -6,7 +6,9 @@ import (
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||
// TODO: / should be valid but an escape character
|
||||
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||
|
||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||
@ -34,6 +36,7 @@ func NewPDANode(state JSONState) *PDANode {
|
||||
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||
|
||||
// TODO: make this a loop
|
||||
startNode := NewPDANode(StateStart)
|
||||
stateToNodeMap[StateStart] = startNode
|
||||
|
||||
@ -95,6 +98,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
intNode := NewPDANode(StateInInt)
|
||||
stateToNodeMap[StateInInt] = intNode
|
||||
|
||||
listObjEndNode := NewPDANode(StateInListObjectEnd)
|
||||
stateToNodeMap[StateInListObjectEnd] = listObjEndNode
|
||||
|
||||
// TODO:
|
||||
// consider adding a node to just point to values, could be good to compute that
|
||||
// mask rather than many different nodes
|
||||
@ -105,108 +111,84 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
|
||||
objNode.TransitionEdges['"'] = objKeyNode
|
||||
objNode.TransitionEdges['\n'] = newlineNode
|
||||
// objNode.TransitionEdges['\t'] = tabNode
|
||||
objNode.TransitionEdges[' '] = spaceObjNode
|
||||
|
||||
//new line
|
||||
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||
newlineNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
tabNode.TransitionEdges['"'] = objKeyNode
|
||||
// tabNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||
|
||||
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||
objEndNode.TransitionEdges[' '] = spaceNode
|
||||
|
||||
objEndNode.TransitionEdges[','] = commaNode
|
||||
objEndNode.TransitionEdges['}'] = objEndNode
|
||||
|
||||
// where values should be
|
||||
// this could be combined but the probs might change, we're alr doing a skip ahead
|
||||
colonNode.TransitionEdges[' '] = spaceNode
|
||||
colonNode.TransitionEdges['['] = listNode
|
||||
colonNode.TransitionEdges['{'] = objNode
|
||||
addValueConnections(colonNode, stateToNodeMap)
|
||||
|
||||
// Leads to a value
|
||||
spaceNode.TransitionEdges['"'] = stringNode
|
||||
spaceNode.TransitionEdges['['] = listNode
|
||||
spaceNode.TransitionEdges['{'] = objNode
|
||||
|
||||
for _, r := range validNumberRunes {
|
||||
spaceNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
spaceNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
|
||||
for _, r := range validNullRunes {
|
||||
spaceNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
addValueConnections(spaceNode, stateToNodeMap)
|
||||
|
||||
// Values
|
||||
// string node
|
||||
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||
stringNode.TransitionEdges['"'] = stringEndNode
|
||||
|
||||
stringEndNode.TransitionEdges[','] = commaNode
|
||||
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||
stringEndNode.TransitionEdges[']'] = listEndNode
|
||||
// String end node
|
||||
addEnds(stringEndNode, stateToNodeMap)
|
||||
|
||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||
// number node
|
||||
for _, r := range validNumberRunes {
|
||||
numberNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
numberNode.TransitionEdges[','] = commaNode
|
||||
numberNode.TransitionEdges['}'] = objEndNode
|
||||
numberNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
|
||||
// list node
|
||||
listNode.TransitionEdges[','] = commaNode
|
||||
listNode.TransitionEdges['"'] = stringNode
|
||||
// squash states to a value
|
||||
for _, r := range validNumberRunes {
|
||||
listNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
listNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
for _, r := range validNullRunes {
|
||||
listNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
nullNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
nullNode.TransitionEdges[','] = commaNode
|
||||
nullNode.TransitionEdges['}'] = objEndNode
|
||||
nullNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
listCommaNode.TransitionEdges['"'] = stringNode
|
||||
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||
listCommaNode.TransitionEdges['{'] = objNode
|
||||
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||
|
||||
for _, r := range validNumberRunes {
|
||||
listCommaNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
listCommaNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
for _, r := range validNullRunes {
|
||||
listCommaNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
addEnds(numberNode, stateToNodeMap)
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
boolNode.TransitionEdges['}'] = objEndNode
|
||||
boolNode.TransitionEdges[']'] = listEndNode
|
||||
boolNode.TransitionEdges[','] = commaNode
|
||||
addEnds(boolNode, stateToNodeMap)
|
||||
|
||||
// list node
|
||||
listNode.TransitionEdges[','] = commaNode
|
||||
listNode.TransitionEdges['{'] = objNode
|
||||
listNode.TransitionEdges[' '] = listNode
|
||||
listNode.TransitionEdges['\n'] = listNode
|
||||
addValueConnections(listNode, stateToNodeMap)
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
nullNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
addEnds(nullNode, stateToNodeMap)
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||
listCommaNode.TransitionEdges['{'] = objNode
|
||||
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||
addValueConnections(listCommaNode, stateToNodeMap)
|
||||
|
||||
// list object end
|
||||
listObjEndNode.TransitionEdges[','] = listCommaNode
|
||||
listObjEndNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
addEnds(boolNode, stateToNodeMap)
|
||||
|
||||
listEndNode.TransitionEdges['}'] = objEndNode
|
||||
listEndNode.TransitionEdges[','] = commaNode
|
||||
@ -218,10 +200,27 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
commaNode.TransitionEdges[' '] = spaceObjNode
|
||||
|
||||
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
||||
spaceObjNode.TransitionEdges['\n'] = newlineNode
|
||||
|
||||
return startNode, stateToNodeMap, nil
|
||||
}
|
||||
|
||||
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
||||
}
|
||||
|
||||
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||
for _, r := range validNumberRunes {
|
||||
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
}
|
||||
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||
|
||||
vocab := proc.GetVocabulary()
|
||||
@ -240,7 +239,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
for i := range vocab.Values {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
valid := true
|
||||
@ -263,6 +262,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
return nil
|
||||
}
|
||||
|
||||
// garbage interface plz fix
|
||||
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return false, nil, nil
|
||||
@ -281,7 +281,6 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
|
||||
if curNode.State == nextNode.State {
|
||||
return false, nil, nil
|
||||
}
|
||||
// fmt.Println("special rune", r, "consumed")
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return true, nextNode, nil
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// TODO: safety in case of invalid json
|
||||
// TODO: interfaces to cleanup with return values
|
||||
type PushdownSampler struct {
|
||||
// stateful
|
||||
curNode *PDANode
|
||||
@ -18,6 +20,7 @@ type PushdownSampler struct {
|
||||
stateCounter uint32
|
||||
}
|
||||
|
||||
// graph should be built once and reused per tokenizer
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
start := time.Now()
|
||||
|
||||
@ -39,14 +42,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
||||
// token, err := proc.Decode([]int32{int32(id)})
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// fmt.Println("id", id, "node", node, "token", token)
|
||||
// }
|
||||
// time.Sleep(10 * time.Second)
|
||||
|
||||
return &PushdownSampler{
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
@ -57,9 +53,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
fmt.Println("sample:", s.curNode.State)
|
||||
|
||||
// fmt.Println(">>> sample:", s.curNode.State)
|
||||
switch s.curNode.State {
|
||||
case StateInString:
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
|
||||
case StateInObjectEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
@ -73,24 +71,24 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
valid, err := s.proc.Encode("}")
|
||||
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
||||
// fmt.Println("switching to list object end", s.curNode.State)
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range logits {
|
||||
for _, token := range valid {
|
||||
if i != int(token) {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInComma:
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||
fmt.Println("switching to list comma", s.curNode.State)
|
||||
// fmt.Println("switching to list comma", s.curNode.State)
|
||||
}
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
@ -109,7 +107,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
// fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -119,54 +117,48 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
|
||||
// TODO: need to handle end states and entering object case, and list case
|
||||
if s.curNode.State == StateInObjectEnd {
|
||||
fmt.Println("in object end")
|
||||
if len(s.braceStack) > 0 {
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
return nil
|
||||
}
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
// TODO: return here?
|
||||
}
|
||||
// need this cause there could be multiple transitions
|
||||
// fmt.Println("update state", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: should force closing for all braces
|
||||
|
||||
// TODO: should force closing for all braces - not doing square yet
|
||||
for _, r := range mappedString {
|
||||
if r == rune('{') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
// fmt.Println("pushing { brace stack", r)
|
||||
}
|
||||
if r == rune('[') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
// fmt.Println("pushing [ brace stack", r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if len(s.braceStack) == 0 || top != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
fmt.Println("popping brace stack", s.braceStack)
|
||||
// fmt.Println("popping { brace stack", top)
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if len(s.braceStack) == 0 || top != rune('[') {
|
||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
fmt.Println("popping brace stack", s.braceStack)
|
||||
// fmt.Println("popping [ brace stack", top)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
fmt.Println("transitioning to", nextNodeState)
|
||||
// fmt.Println("transitioning to", nextNodeState)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNodeState == s.curNode.State {
|
||||
|
98
sample/structured_outputs.go
Normal file
98
sample/structured_outputs.go
Normal file
@ -0,0 +1,98 @@
|
||||
package sample
|
||||
|
||||
import "github.com/ollama/ollama/model"
|
||||
|
||||
type StructuredOutput struct {
|
||||
schema *Schema
|
||||
}
|
||||
|
||||
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
|
||||
// _, stateToNodeMap, err := BuildGraph(proc)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
|
||||
// // If no schema constraints, return original graph node
|
||||
// if schema == nil {
|
||||
// return graph
|
||||
// }
|
||||
|
||||
// // Create a new node with same state
|
||||
// constrainedNode := NewPDANode(graph.State)
|
||||
|
||||
// // Copy over existing transitions and masks
|
||||
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
||||
// for r, node := range graph.TransitionEdges {
|
||||
// constrainedNode.TransitionEdges[r] = node
|
||||
// }
|
||||
// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
|
||||
|
||||
// // Apply schema constraints based on type
|
||||
// switch schema.EffectiveType() {
|
||||
// case "object":
|
||||
// // Only allow defined property names in object keys
|
||||
// if graph.State == StateInObjectKey {
|
||||
// // TODO: Add property name validation
|
||||
// }
|
||||
|
||||
// // Constrain property values based on schema
|
||||
// if graph.State == StateInColon || graph.State == StateInSpace {
|
||||
// // Clear transitions to only allow valid types
|
||||
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
||||
|
||||
// // Add transitions based on property schemas
|
||||
// for _, prop := range schema.Properties {
|
||||
// switch prop.EffectiveType() {
|
||||
// case "object":
|
||||
// if objNode, ok := graph.TransitionEdges['{']; ok {
|
||||
// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
|
||||
// }
|
||||
// case "array":
|
||||
// if arrNode, ok := graph.TransitionEdges['[']; ok {
|
||||
// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
|
||||
// }
|
||||
// case "string":
|
||||
// if strNode, ok := graph.TransitionEdges['"']; ok {
|
||||
// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
|
||||
// }
|
||||
// case "number":
|
||||
// for _, r := range validNumberRunes {
|
||||
// if numNode, ok := graph.TransitionEdges[r]; ok {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
|
||||
// }
|
||||
// }
|
||||
// case "integer":
|
||||
// for _, r := range validIntRunes {
|
||||
// if intNode, ok := graph.TransitionEdges[r]; ok {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
|
||||
// }
|
||||
// }
|
||||
// case "boolean":
|
||||
// for _, r := range []rune{'t', 'f'} {
|
||||
// if boolNode, ok := graph.TransitionEdges[r]; ok {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
|
||||
// }
|
||||
// }
|
||||
// case "null":
|
||||
// if nullNode, ok := graph.TransitionEdges['n']; ok {
|
||||
// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// case "array":
|
||||
// // Constrain array items based on schema
|
||||
// if schema.Items != nil {
|
||||
// for r, node := range graph.TransitionEdges {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return constrainedNode
|
||||
// }
|
Loading…
x
Reference in New Issue
Block a user