just for fun

This commit is contained in:
Roy Han 2024-07-25 17:34:11 -07:00
parent 918fd32884
commit 42009d2974
3 changed files with 256 additions and 140 deletions

View File

@ -2,6 +2,7 @@ package cmd
import (
"archive/zip"
"bufio"
"bytes"
"context"
"crypto/ed25519"
@ -16,6 +17,7 @@ import (
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
@ -31,6 +33,11 @@ import (
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
@ -370,6 +377,90 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
// EmbedHandler is the RunE handler for the "embed" command.
//
// args[0] is the model name; any remaining arguments are joined with spaces
// to form the prompt. When stdin is not a terminal its full contents are
// prepended to the prompt. The "format", "keepalive" and "nowordwrap" flags
// are copied into the run options. The model is looked up on the server and
// pulled first if the server reports it as not found. With no prompt at all
// the interactive session is started; otherwise the prompt is handed to embed.
func EmbedHandler(cmd *cobra.Command, args []string) error {
	modelName := args[0]

	opts := runOptions{
		Model:    modelName,
		WordWrap: os.Getenv("TERM") == "xterm-256color",
		Options:  map[string]any{},
	}

	var err error
	if opts.Format, err = cmd.Flags().GetString("format"); err != nil {
		return err
	}

	rawKeepAlive, err := cmd.Flags().GetString("keepalive")
	if err != nil {
		return err
	}
	if rawKeepAlive != "" {
		ttl, err := time.ParseDuration(rawKeepAlive)
		if err != nil {
			return err
		}
		opts.KeepAlive = &api.Duration{Duration: ttl}
	}

	// Piped stdin goes in front of whatever was given on the command line.
	prompts := args[1:]
	if !term.IsTerminal(int(os.Stdin.Fd())) {
		piped, err := io.ReadAll(os.Stdin)
		if err != nil {
			return err
		}
		prompts = append([]string{string(piped)}, prompts...)
		opts.WordWrap = false
	}
	opts.Prompt = strings.Join(prompts, " ")

	// Piped input always contributes one prompt entry, so an empty list means
	// neither stdin nor the command line supplied anything.
	interactive := len(prompts) == 0

	// The nowordwrap flag has the final say on wrapping.
	noWrap, err := cmd.Flags().GetBool("nowordwrap")
	if err != nil {
		return err
	}
	opts.WordWrap = !noWrap

	// Fill out the rest of the options based on information about the model.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	info, err := client.Show(cmd.Context(), &api.ShowRequest{Name: modelName})
	var statusErr api.StatusError
	if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
		// Unknown to the server: pull it, then ask again.
		if err := PullHandler(cmd, []string{modelName}); err != nil {
			return err
		}
		info, err = client.Show(cmd.Context(), &api.ShowRequest{Name: modelName})
	}
	if err != nil {
		return err
	}

	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
	opts.ParentModel = info.Details.ParentModel
	opts.Messages = append(opts.Messages, info.Messages...)

	if !interactive {
		return embed(cmd, opts)
	}
	return generateInteractive(cmd, opts)
}
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
sshKeyPattern := `ssh-\w+ [^\s"]+`
@ -979,6 +1070,154 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
// embed requests embeddings for the inputs in opts.Prompt, projects them onto
// their first two principal components, prints the 2D coordinates and shows a
// labelled scatter plot of the result.
//
// Inputs are the segments of opts.Prompt separated by a blank line ("\n\n");
// surrounding whitespace is trimmed and empty segments are dropped. At least
// two inputs are required because the projection needs two principal
// components.
//
// The plot is written as plot.svg inside a fresh temporary directory, opened
// with the "open" command, and the directory is removed when the function
// returns (after the user presses Enter, or on any error once it exists).
//
// It returns an error if the client cannot be created, the embed request
// fails, the response is unusable for a 2D projection, PCA fails, or the plot
// cannot be built, saved or opened.
func embed(cmd *cobra.Command, opts runOptions) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return fmt.Errorf("couldn't connect to ollama server: %w", err)
	}

	var inputs []string
	for _, segment := range strings.Split(opts.Prompt, "\n\n") {
		if segment = strings.TrimSpace(segment); segment != "" {
			inputs = append(inputs, segment)
		}
	}
	if len(inputs) < 2 {
		return errors.New("embed needs at least two inputs separated by a blank line")
	}

	resp, err := client.Embed(cmd.Context(), &api.EmbedRequest{
		Model: opts.Model,
		Input: inputs,
	})
	if err != nil {
		return fmt.Errorf("couldn't get embeddings: %w", err)
	}

	// Validate the response shape up front; everything below indexes into it.
	embeddings := resp.Embeddings
	if len(embeddings) != len(inputs) {
		return fmt.Errorf("expected %d embeddings, got %d", len(inputs), len(embeddings))
	}
	r, c := len(embeddings), len(embeddings[0])
	if c < 2 {
		return fmt.Errorf("embeddings have %d dimension(s), need at least 2 to plot", c)
	}

	// Flatten into row-major order for mat.NewDense.
	data := make([]float64, 0, r*c)
	for i, row := range embeddings {
		if len(row) != c {
			return fmt.Errorf("embedding %d has %d dimensions, expected %d", i+1, len(row), c)
		}
		for _, v := range row {
			data = append(data, float64(v))
		}
	}
	X := mat.NewDense(r, c, data)

	var pca stat.PC
	if !pca.PrincipalComponents(X, nil) {
		return errors.New("PCA failed")
	}

	// The checks above (r >= 2, c >= 2) are what make taking two columns
	// here safe, assuming VectorsTo yields min(r, c) component columns.
	var vectors mat.Dense
	pca.VectorsTo(&vectors)
	W := vectors.Slice(0, c, 0, 2)

	// NOTE(review): X is projected as-is, without subtracting column means.
	// If the PCA is computed on centred data, every point carries the same
	// constant offset — confirm whether centring is wanted here.
	var reduced mat.Dense
	reduced.Mul(X, W)

	points := make(plotter.XYs, r)
	labels := make([]string, r)
	for i, input := range inputs {
		x, y := reduced.At(i, 0), reduced.At(i, 1)
		fmt.Print(i+1, ". ", input, "\n")
		fmt.Printf("[%v, %v]\n\n", x, y)
		points[i].X, points[i].Y = x, y
		// Points are labelled with the same 1-based index printed above.
		labels[i] = fmt.Sprintf("%d", i+1)
	}

	p := plot.New()
	p.Title.Text = "Embedding Map"

	scatter, err := plotter.NewScatter(points)
	if err != nil {
		return fmt.Errorf("couldn't create scatter plot: %w", err)
	}
	p.Add(scatter)

	pointLabels, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
	if err != nil {
		return fmt.Errorf("couldn't create plot labels: %w", err)
	}
	p.Add(pointLabels)

	// Fix both axes to [-1, 1] with identical ticks so the grid is square.
	unitTicks := func() plot.ConstantTicks {
		return plot.ConstantTicks{
			{Value: -1, Label: "-1"},
			{Value: -0.5, Label: "-0.5"},
			{Value: 0, Label: "0"},
			{Value: 0.5, Label: "0.5"},
			{Value: 1, Label: "1"},
		}
	}
	p.X.Min, p.X.Max = -1, 1
	p.Y.Min, p.Y.Max = -1, 1
	p.X.Tick.Marker = unitTicks()
	p.Y.Tick.Marker = unitTicks()

	// Write into a private temporary directory so an existing plot.svg in the
	// working directory is never overwritten or deleted. Cleanup is deferred
	// immediately so it also runs on the error paths below.
	dir, err := os.MkdirTemp("", "ollama-embed-")
	if err != nil {
		return fmt.Errorf("couldn't create temporary directory for plot: %w", err)
	}
	defer func() {
		if err := os.RemoveAll(dir); err != nil {
			fmt.Println("error: couldn't delete plot")
		}
	}()

	plotPath := filepath.Join(dir, "plot.svg")
	if err := p.Save(6*vg.Inch, 6*vg.Inch, plotPath); err != nil {
		return fmt.Errorf("couldn't save plot: %w", err)
	}

	if err := exec.Command("open", plotPath).Run(); err != nil {
		return fmt.Errorf("couldn't open plot: %w", err)
	}

	// Keep the file around until the user is done looking at it. The read
	// error is deliberately ignored: EOF or a closed stdin simply means there
	// is nothing to wait for.
	fmt.Print("Press 'Enter' to continue")
	_, _ = bufio.NewReader(os.Stdin).ReadString('\n')

	return nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@ -1247,11 +1486,26 @@ func NewCLI() *cobra.Command {
RunE: RunHandler,
}
embedCmd := &cobra.Command{
Use: "embed MODEL [PROMPT]",
Short: "Embed a model",
Args: cobra.MinimumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: EmbedHandler,
}
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
embedCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
embedCmd.Flags().Bool("verbose", false, "Show timings for response")
embedCmd.Flags().Bool("insecure", false, "Use an insecure registry")
embedCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
embedCmd.Flags().String("format", "", "Response format (e.g. json)")
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},
@ -1326,6 +1580,7 @@ func NewCLI() *cobra.Command {
copyCmd,
deleteCmd,
serveCmd,
embedCmd,
} {
switch cmd {
case runCmd:
@ -1361,6 +1616,7 @@ func NewCLI() *cobra.Command {
psCmd,
copyCmd,
deleteCmd,
embedCmd,
)
return rootCmd

View File

@ -1,7 +1,6 @@
package cmd
import (
"encoding/json"
"errors"
"fmt"
"io"
@ -14,11 +13,6 @@ import (
"strings"
"github.com/spf13/cobra"
"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/vg"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@ -453,140 +447,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/embed"):
line = strings.TrimPrefix(line, "/embed")
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
var strArray []string
fmt.Printf("line is %s\n", line)
err = json.Unmarshal([]byte(line), &strArray)
if err != nil {
fmt.Println("error: couldn't parse input")
return err
}
for i, s := range strArray {
fmt.Printf("strArray[%d] is %s\n", i, s)
}
req := &api.EmbedRequest{
Model: opts.Model,
Input: strArray,
}
resp, err := client.Embed(cmd.Context(), req)
if err != nil {
fmt.Println("error: couldn't get embeddings")
return err
}
embeddings := resp.Embeddings
r, c := len(embeddings), len(embeddings[0])
data := make([]float64, r*c)
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
data[i*c+j] = float64(embeddings[i][j])
}
}
X := mat.NewDense(r, c, data)
// Initialize PCA
var pca stat.PC
// Perform PCA
if !pca.PrincipalComponents(X, nil) {
return fmt.Errorf("PCA failed")
}
// Extract principal component vectors
var vectors mat.Dense
pca.VectorsTo(&vectors)
// // Extract variances of the principal components
// var variances []float64
// variances = pca.VarsTo(variances)
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
// Perform PCA reduction
var reducedData mat.Dense
reducedData.Mul(X, W)
// Print the projected 2D points
fmt.Println("Reduced embeddings to 2D:")
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
row := reducedData.RowView(i)
fmt.Printf("[%v, %v]\n", row.AtVec(0), row.AtVec(1))
}
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
row := reducedData.RowView(i)
points[i].X = row.AtVec(0)
points[i].Y = row.AtVec(1)
}
// Create a new plot
p := plot.New()
// Set plot title and axis labels
p.Title.Text = "2D Data Plot"
p.X.Label.Text = "X"
p.Y.Label.Text = "Y"
// Create a scatter plot of the points
s, err := plotter.NewScatter(points)
if err != nil {
panic(err)
}
p.Add(s)
/// Create labels plotter and add it to the plot
labels := make([]string, reducedData.RawMatrix().Rows)
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
labels[i] = fmt.Sprintf("%d", i+1)
}
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
if err != nil {
panic(err)
}
p.Add(l)
// Make the grid square
p.X.Min = -1
p.X.Max = 1
p.Y.Min = -1
p.Y.Max = 1
// Set the aspect ratio to be 1:1
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
{Value: -1, Label: "-1"},
{Value: -0.5, Label: "-0.5"},
{Value: 0, Label: "0"},
{Value: 0.5, Label: "0.5"},
{Value: 1, Label: "1"},
})
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
{Value: -1, Label: "-1"},
{Value: -0.5, Label: "-0.5"},
{Value: 0, Label: "0"},
{Value: 0.5, Label: "0.5"},
{Value: 1, Label: "1"},
})
// Save the plot to a PNG file
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.png"); err != nil {
panic(err)
}
case strings.HasPrefix(line, "/"):
args := strings.Fields(line)
isFile := false

BIN
plot.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.2 KiB