Enable array type json

This commit is contained in:
ParthSareen 2025-01-29 14:54:39 -08:00
parent 77f709ebd5
commit 198fde82aa
3 changed files with 28 additions and 9 deletions

View File

@ -29,7 +29,6 @@ const (
StateInObjSpace StateInObjSpace
StateInList StateInList
StateInListComma StateInListComma
StateListEnd
StateInValue StateInValue
StateInValueEnd StateInValueEnd
StateInListEnd StateInListEnd
@ -63,7 +62,6 @@ var JSONStates = []JSONState{
StateInObjSpace, StateInObjSpace,
StateInList, StateInList,
StateInListComma, StateInListComma,
StateListEnd,
StateInValue, StateInValue,
StateInValueEnd, StateInValueEnd,
StateInListEnd, StateInListEnd,
@ -118,8 +116,6 @@ func (s JSONState) String() string {
return "StateInListObjectEnd" return "StateInListObjectEnd"
case StateInListComma: case StateInListComma:
return "StateInListComma" return "StateInListComma"
case StateListEnd:
return "StateListEnd"
case StateInListEnd: case StateInListEnd:
return "StateInListEnd" return "StateInListEnd"
case StateInNewline: case StateInNewline:

View File

@ -47,6 +47,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// Connect nodes // Connect nodes
// TODO: if all are single tokens then this can just be connected instead of defining the token // TODO: if all are single tokens then this can just be connected instead of defining the token
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
@ -121,7 +122,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// list object end // list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateListEnd] stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// bool node // bool node
for _, r := range validBoolRunes { for _, r := range validBoolRunes {
@ -129,8 +130,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
} }
addEnds(stateToNodeMap[StateInBool], stateToNodeMap) addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
@ -147,7 +148,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma] node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateListEnd] node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
} }
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {

View File

@ -58,6 +58,27 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
case StateInString: case StateInString:
return s.maskLogits(logits, s.curNode) return s.maskLogits(logits, s.curNode)
case StateInListEnd:
fmt.Println("in list end", s.braceStack)
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.NaN()
}
}
return logits, nil
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateInObjectEnd: case StateInObjectEnd:
// force finish if no braces left // force finish if no braces left
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
@ -117,11 +138,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
} }
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
// fmt.Println("update state", s.curNode.State) fmt.Println("update state", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice) mappedString, err := s.proc.Decode(tokenSlice)
if err != nil { if err != nil {
return err return err
} }
fmt.Println("mappedString", mappedString)
// TODO: should force closing for all braces - not doing square yet // TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString { for _, r := range mappedString {