Skip to content

Commit

Permalink
fix(progressbar): hopefully fix the progress bar updating (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
sammcj authored Dec 29, 2024
1 parent cb6d882 commit 12481a1
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 21 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The application allows users to interactively select models, sort, filter, edit,
- [go install (recommended)](#go-install-recommended)
- [curl](#curl)
- [Manually](#manually)
- [if "command not found: gollama"](#if-command-not-found-gollama)
- [Usage](#usage)
- [Key Bindings](#key-bindings)
- [Top](#top)
Expand Down Expand Up @@ -110,11 +111,11 @@ echo "alias g=gollama" >> ~/.zshrc
- `i`: Inspect model
- `t`: Top (show running models)
- `D`: Delete model
- `e`: Edit model **new**
- `e`: Edit model
- `c`: Copy model
- `U`: Unload all models
- `p`: Pull an existing model **new**
- `g`: Pull (get) new model **new**
- `p`: Pull an existing model
- `ctrl+p`: Pull (get) new model
- `P`: Push model
- `n`: Sort by name
- `s`: Sort by size
Expand Down Expand Up @@ -159,7 +160,7 @@ Note: Requires Admin privileges if you're running Windows.
- `-u`: Unload all running models
- `-v`: Print the version and exit
- `-h`, or `--host`: Specify the host for the Ollama API
- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API) **new**
- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API)
- `--vram`: Estimate vRAM usage for a model. Accepts:
- Ollama models (e.g. `llama3.1:8b-instruct-q6_K`, `qwen2:14b-q4_0`)
- HuggingFace models (e.g. `NousResearch/Hermes-2-Theta-Llama-3-8B`)
Expand Down
14 changes: 11 additions & 3 deletions app_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,17 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m.handlePullErrorMsg(msg)
case progressMsg:
if m.pullProgress < 1.0 {
m.pullProgress = msg.progress
return m, m.updateProgressCmd()
return m, tea.Batch(
m.updateProgressCmd(),
func() tea.Msg {
return progressMsg{
modelName: msg.modelName,
progress: m.pullProgress,
}
},
)
}
return m, nil
}
}
switch msg := msg.(type) {
Expand Down Expand Up @@ -412,7 +420,7 @@ func (m *AppModel) handlePullErrorMsg(msg pullErrorMsg) (tea.Model, tea.Cmd) {
}

func (m *AppModel) updateProgressCmd() tea.Cmd {
return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg {
return tea.Tick(time.Second, func(t time.Time) tea.Msg {
return progressMsg{
modelName: m.pullInput.Value(),
progress: m.pullProgress,
Expand Down
2 changes: 1 addition & 1 deletion keymap.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewKeyMap() *KeyMap {
LinkModel: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "link (L=all)")),
PushModel: key.NewBinding(key.WithKeys("P"), key.WithHelp("P", "push")),
PullModel: key.NewBinding(key.WithKeys("p"), key.WithHelp("p", "pull")),
PullNewModel: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "get")),
PullNewModel: key.NewBinding(key.WithKeys("ctrl+p"), key.WithHelp("ctrl+p", "pull new model")),
Quit: key.NewBinding(key.WithKeys("q")),
RunModel: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "run")),
SortByFamily: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "^family")),
Expand Down
58 changes: 45 additions & 13 deletions operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,54 @@ func (m *AppModel) startPullModel(modelName string) tea.Cmd {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
if !m.pulling {
return context.Canceled
progressChan := make(chan float64)
errChan := make(chan error)

go func() {
req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
if !m.pulling {
return context.Canceled
}
progress := float64(resp.Completed) / float64(resp.Total)
m.pullProgress = progress
progressChan <- progress
return nil
})

if err == context.Canceled {
errChan <- fmt.Errorf("pull cancelled")
return
}
m.pullProgress = float64(resp.Completed) / float64(resp.Total)
return nil
})
if err != nil {
errChan <- err
return
}
close(progressChan)
}()

if err == context.Canceled {
return pullErrorMsg{fmt.Errorf("pull cancelled")}
}
if err != nil {
return pullErrorMsg{err}
// Start a ticker to send progress updates
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

for {
select {
case err := <-errChan:
if err != nil {
return pullErrorMsg{err}
}
return pullSuccessMsg{modelName}
case <-ticker.C:
return progressMsg{
modelName: modelName,
progress: m.pullProgress,
}
case progress := <-progressChan:
if progress >= 1.0 {
return pullSuccessMsg{modelName}
}
}
}
return pullSuccessMsg{modelName}
}
}

Expand Down

0 comments on commit 12481a1

Please sign in to comment.