From fcfbb06f1bd3f9bda2641d9e7c34e83a5e73d257 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Feb 2025 09:48:13 -0800 Subject: [PATCH] cmd: handle sigint globally This change also updates both client.do and client.stream to return ctx.Err(). Previously this error is skipped so canceled contexts are silently ignored --- api/client.go | 5 +++-- cmd/cmd.go | 38 ++++---------------------------------- main.go | 14 +++++++++++++- 3 files changed, 20 insertions(+), 37 deletions(-) diff --git a/api/client.go b/api/client.go index 4688d4d13..a51c75dfa 100644 --- a/api/client.go +++ b/api/client.go @@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData return err } } - return nil + + return ctx.Err() } const maxBufferSize = 512 * format.KiloByte @@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f } } - return nil + return ctx.Err() } // GenerateResponseFunc is a function that [Client.Generate] invokes every time diff --git a/cmd/cmd.go b/cmd/cmd.go index 80ece4c60..205df7b0b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -15,13 +15,11 @@ import ( "net" "net/http" "os" - "os/signal" "path/filepath" "runtime" "strconv" "strings" "sync/atomic" - "syscall" "time" "github.com/containerd/console" @@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { if err := PullHandler(cmd, []string{name}); err != nil { return nil, err } + return client.Show(cmd.Context(), &api.ShowRequest{Name: name}) } return info, err @@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { spinner := progress.NewSpinner("") p.Add("", spinner) - cancelCtx, cancel := context.WithCancel(cmd.Context()) - defer cancel() - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT) - - go func() { - <-sigChan - cancel() - }() - var state *displayResponseState = &displayResponseState{} var latest api.ChatResponse var fullResponse strings.Builder @@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { req.KeepAlive = opts.KeepAlive } - if err := client.Chat(cancelCtx, req, fn); err != nil { - if errors.Is(err, context.Canceled) { - return nil, nil - } + if err := client.Chat(cmd.Context(), req, fn); err != nil { return nil, err } @@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error { generateContext = []int{} } - ctx, cancel := context.WithCancel(cmd.Context()) - defer cancel() - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT) - - go func() { - <-sigChan - cancel() - }() - var state *displayResponseState = &displayResponseState{} fn := func(response api.GenerateResponse) error { @@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { KeepAlive: opts.KeepAlive, } - if err := client.Generate(ctx, &request, fn); err != nil { - if errors.Is(err, context.Canceled) { - return nil - } + if err := client.Generate(cmd.Context(), &request, fn); err != nil { return err } @@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { latest.Summary() } - ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context) - cmd.SetContext(ctx) + cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)) return nil } diff --git a/main.go b/main.go index 650e03a63..a9e92a311 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,8 @@ package main import ( "context" + "os" + "os/signal" "github.com/spf13/cobra" @@ -9,5 +11,15 @@ import ( ) func main() { - cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background())) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + go func() { + <-sigChan + cancel() + }() + + cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx)) }