Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Yang
fcfbb06f1b cmd: handle sigint globally
This change also updates both client.do and client.stream to return
ctx.Err(). Previously this error is skipped so canceled contexts are
silently ignored
2025-02-19 10:46:25 -08:00
Michael Yang
e8d35d0de0 cmd: fix hide cursor
hides the cursor for the entire progress rather than each render cycle
2025-02-19 09:43:44 -08:00
4 changed files with 36 additions and 56 deletions

View File

@ -126,7 +126,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return err
}
}
return nil
return ctx.Err()
}
const maxBufferSize = 512 * format.KiloByte
@ -189,7 +190,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
}
return nil
return ctx.Err()
}
// GenerateResponseFunc is a function that [Client.Generate] invokes every time

View File

@ -15,13 +15,11 @@ import (
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/containerd/console"
@ -330,6 +328,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
}
return info, err
@ -858,17 +857,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
@ -903,10 +891,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
if err := client.Chat(cmd.Context(), req, fn); err != nil {
return nil, err
}
@ -946,17 +931,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
generateContext = []int{}
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
fn := func(response api.GenerateResponse) error {
@ -992,10 +966,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
if err := client.Generate(cmd.Context(), &request, fn); err != nil {
return err
}
@ -1017,8 +988,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
latest.Summary()
}
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
return nil
}

14
main.go
View File

@ -2,6 +2,8 @@ package main
import (
"context"
"os"
"os/signal"
"github.com/spf13/cobra"
@ -9,5 +11,15 @@ import (
)
func main() {
cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background()))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
go func() {
<-sigChan
cancel()
}()
cobra.CheckErr(cmd.NewCLI().ExecuteContext(ctx))
}

View File

@ -49,29 +49,29 @@ func (p *Progress) stop() bool {
func (p *Progress) Stop() bool {
stopped := p.stop()
if stopped {
fmt.Fprint(p.w, "\n")
p.w.Flush()
fmt.Fprintln(p.w)
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
func (p *Progress) StopAndClear() bool {
defer p.w.Flush()
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
stopped := p.stop()
if stopped {
// clear all progress lines
for i := range p.pos {
if i > 0 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K\033[1G")
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[2K", "\033[1G")
}
// show cursor
fmt.Fprint(p.w, "\033[?25h")
p.w.Flush()
return stopped
}
@ -86,19 +86,13 @@ func (p *Progress) render() {
p.mu.Lock()
defer p.mu.Unlock()
defer p.w.Flush()
// eliminate flickering on terminals that support synchronized output
fmt.Fprint(p.w, "\033[?2026h")
defer fmt.Fprint(p.w, "\033[?2026l")
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")
// move the cursor back to the beginning
for range p.pos - 1 {
fmt.Fprint(p.w, "\033[A")
}
fmt.Fprint(p.w, "\033[1G")
// render progress lines
@ -110,10 +104,13 @@ func (p *Progress) render() {
}
p.pos = len(p.states)
p.w.Flush()
}
func (p *Progress) start() {
p.ticker = time.NewTicker(100 * time.Millisecond)
// hide cursor
fmt.Fprint(p.w, "\033[?25l")
for range p.ticker.C {
p.render()
}