Enable array type json
This commit is contained in:
parent
77f709ebd5
commit
198fde82aa
@ -29,7 +29,6 @@ const (
|
|||||||
StateInObjSpace
|
StateInObjSpace
|
||||||
StateInList
|
StateInList
|
||||||
StateInListComma
|
StateInListComma
|
||||||
StateListEnd
|
|
||||||
StateInValue
|
StateInValue
|
||||||
StateInValueEnd
|
StateInValueEnd
|
||||||
StateInListEnd
|
StateInListEnd
|
||||||
@ -63,7 +62,6 @@ var JSONStates = []JSONState{
|
|||||||
StateInObjSpace,
|
StateInObjSpace,
|
||||||
StateInList,
|
StateInList,
|
||||||
StateInListComma,
|
StateInListComma,
|
||||||
StateListEnd,
|
|
||||||
StateInValue,
|
StateInValue,
|
||||||
StateInValueEnd,
|
StateInValueEnd,
|
||||||
StateInListEnd,
|
StateInListEnd,
|
||||||
@ -118,8 +116,6 @@ func (s JSONState) String() string {
|
|||||||
return "StateInListObjectEnd"
|
return "StateInListObjectEnd"
|
||||||
case StateInListComma:
|
case StateInListComma:
|
||||||
return "StateInListComma"
|
return "StateInListComma"
|
||||||
case StateListEnd:
|
|
||||||
return "StateListEnd"
|
|
||||||
case StateInListEnd:
|
case StateInListEnd:
|
||||||
return "StateInListEnd"
|
return "StateInListEnd"
|
||||||
case StateInNewline:
|
case StateInNewline:
|
||||||
|
@ -47,6 +47,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
// Connect nodes
|
// Connect nodes
|
||||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
|
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||||
|
|
||||||
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||||
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||||
@ -121,7 +122,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
|
|
||||||
// list object end
|
// list object end
|
||||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
|
|
||||||
// bool node
|
// bool node
|
||||||
for _, r := range validBoolRunes {
|
for _, r := range validBoolRunes {
|
||||||
@ -129,8 +130,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
}
|
}
|
||||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||||
|
|
||||||
stateToNodeMap[StateListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
stateToNodeMap[StateListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
|
||||||
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||||
@ -147,7 +148,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||||
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||||
}
|
}
|
||||||
|
|
||||||
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||||
|
@ -58,6 +58,27 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
case StateInString:
|
case StateInString:
|
||||||
return s.maskLogits(logits, s.curNode)
|
return s.maskLogits(logits, s.curNode)
|
||||||
|
|
||||||
|
case StateInListEnd:
|
||||||
|
fmt.Println("in list end", s.braceStack)
|
||||||
|
// force finish if no braces left
|
||||||
|
if len(s.braceStack) == 0 {
|
||||||
|
s.curNode = NewPDANode(StateTerminate)
|
||||||
|
for i := range logits {
|
||||||
|
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||||
|
logits[i] = 1.0
|
||||||
|
} else {
|
||||||
|
logits[i] = math.NaN()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return logits, nil
|
||||||
|
|
||||||
case StateInObjectEnd:
|
case StateInObjectEnd:
|
||||||
// force finish if no braces left
|
// force finish if no braces left
|
||||||
if len(s.braceStack) == 0 {
|
if len(s.braceStack) == 0 {
|
||||||
@ -117,11 +138,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||||
// fmt.Println("update state", s.curNode.State)
|
fmt.Println("update state", s.curNode.State)
|
||||||
mappedString, err := s.proc.Decode(tokenSlice)
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
fmt.Println("mappedString", mappedString)
|
||||||
|
|
||||||
// TODO: should force closing for all braces - not doing square yet
|
// TODO: should force closing for all braces - not doing square yet
|
||||||
for _, r := range mappedString {
|
for _, r := range mappedString {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user