cmd: handle sigint globally
This change also updates both client.do and client.stream to return ctx.Err(). Previously this error is skipped so canceled contexts are silently ignored
This commit is contained in:
parent
e8d35d0de0
commit
fcfbb06f1b
@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
const maxBufferSize = 512 * format.KiloByte
|
const maxBufferSize = 512 * format.KiloByte
|
||||||
@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateResponseFunc is a function that [Client.Generate] invokes every time
|
// GenerateResponseFunc is a function that [Client.Generate] invokes every time
|
||||||
|
38
cmd/cmd.go
38
cmd/cmd.go
@ -15,13 +15,11 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/containerd/console"
|
"github.com/containerd/console"
|
||||||
@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||||
}
|
}
|
||||||
return info, err
|
return info, err
|
||||||
@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
|||||||
spinner := progress.NewSpinner("")
|
spinner := progress.NewSpinner("")
|
||||||
p.Add("", spinner)
|
p.Add("", spinner)
|
||||||
|
|
||||||
cancelCtx, cancel := context.WithCancel(cmd.Context())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-sigChan
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var state *displayResponseState = &displayResponseState{}
|
var state *displayResponseState = &displayResponseState{}
|
||||||
var latest api.ChatResponse
|
var latest api.ChatResponse
|
||||||
var fullResponse strings.Builder
|
var fullResponse strings.Builder
|
||||||
@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
|||||||
req.KeepAlive = opts.KeepAlive
|
req.KeepAlive = opts.KeepAlive
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
if err := client.Chat(cmd.Context(), req, fn); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
generateContext = []int{}
|
generateContext = []int{}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-sigChan
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var state *displayResponseState = &displayResponseState{}
|
var state *displayResponseState = &displayResponseState{}
|
||||||
|
|
||||||
fn := func(response api.GenerateResponse) error {
|
fn := func(response api.GenerateResponse) error {
|
||||||
@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
KeepAlive: opts.KeepAlive,
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
latest.Summary()
|
latest.Summary()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
|
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
|
||||||
cmd.SetContext(ctx)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
14
main.go
14
main.go
@ -2,6 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
@ -9,5 +11,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, os.Interrupt)
|
||||||
|
go func() {
|
||||||
|
<-sigChan
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user