This commit is contained in:
ParthSareen 2025-03-25 16:45:27 -07:00
parent 5ec6bb52a0
commit 4450f871db
2 changed files with 25 additions and 12 deletions

View File

@ -582,16 +582,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
// return // return
// } // }
// jsonSampler = nil // jsonSampler = nil
// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil) pythonSampler := &sample.PythonSampler{}
// pythonSampler := &sample.PythonSampler{} functions := []sample.PythonFunction{
// functions := []sample.PythonFunction{ {
// { Name: "add_two_strings",
// Name: "add_two_strings", Args: []string{"s1", "s2"},
// Args: []string{"s1", "s2"}, Types: []string{"string", "string"},
// Types: []string{"string", "string"}, },
// }, }
// } pythonSampler.Init(functions, s.model.(model.TextProcessor))
// pythonSampler.Init(functions, s.model.(model.TextProcessor))
sampler := sample.NewSampler( sampler := sample.NewSampler(
req.Options.Temperature, req.Options.Temperature,
req.Options.TopK, req.Options.TopK,
@ -600,7 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.Seed, req.Options.Seed,
grammar, grammar,
nil, nil,
nil, pythonSampler,
// nil,
) )
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{

View File

@ -108,10 +108,13 @@ type PythonSampler struct {
proc model.TextProcessor proc model.TextProcessor
decodedToks []string decodedToks []string
curNode *Node curNode *Node
completed int
functions []PythonFunction
} }
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error { func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values)) decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values { for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)}) token, err := proc.Decode([]int32{int32(i)})
@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
// String end // String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs] s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate] // s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number // Number
for _, r := range validNumberRunes { for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber] s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
if !ok { if !ok {
return fmt.Errorf("invalid token: %q", mappedString) return fmt.Errorf("invalid token: %q", mappedString)
} }
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode s.curNode = nextNode
fmt.Println("curNode", s.curNode.State) fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges { for r, node := range s.curNode.TransitionEdges {