Skip to content

Commit ea96ead

Browse files
committed
feat(tui): handle --model and --prompt flags
1 parent 6100a77 commit ea96ead

File tree

4 files changed

+112
-81
lines changed

4 files changed

+112
-81
lines changed

packages/tui/cmd/opencode/main.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010

1111
tea "github.com/charmbracelet/bubbletea/v2"
12+
flag "github.com/spf13/pflag"
1213
"github.com/sst/opencode-sdk-go"
1314
"github.com/sst/opencode-sdk-go/option"
1415
"github.com/sst/opencode/internal/app"
@@ -23,6 +24,10 @@ func main() {
2324
version = "v" + Version
2425
}
2526

27+
var model *string = flag.String("model", "", "model to begin with")
28+
var prompt *string = flag.String("prompt", "", "prompt to begin with")
29+
flag.Parse()
30+
2631
url := os.Getenv("OPENCODE_SERVER")
2732

2833
appInfoStr := os.Getenv("OPENCODE_APP_INFO")
@@ -65,7 +70,7 @@ func main() {
6570
ctx, cancel := context.WithCancel(context.Background())
6671
defer cancel()
6772

68-
app_, err := app.New(ctx, version, appInfo, httpClient)
73+
app_, err := app.New(ctx, version, appInfo, httpClient, model, prompt)
6974
if err != nil {
7075
panic(err)
7176
}

packages/tui/go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ require (
1717
github.com/muesli/termenv v0.16.0
1818
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
1919
github.com/sst/opencode-sdk-go v0.1.0-alpha.8
20-
github.com/tidwall/gjson v1.14.4
2120
rsc.io/qr v0.2.0
2221
)
2322

@@ -50,6 +49,7 @@ require (
5049
github.com/sosodev/duration v1.3.1 // indirect
5150
github.com/speakeasy-api/openapi-overlay v0.9.0 // indirect
5251
github.com/spf13/cobra v1.9.1 // indirect
52+
github.com/tidwall/gjson v1.14.4 // indirect
5353
github.com/tidwall/match v1.1.1 // indirect
5454
github.com/tidwall/pretty v1.2.1 // indirect
5555
github.com/tidwall/sjson v1.2.5 // indirect

packages/tui/internal/app/app.go

Lines changed: 93 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,19 @@ import (
2121
)
2222

2323
type App struct {
24-
Info opencode.App
25-
Version string
26-
StatePath string
27-
Config *opencode.Config
28-
Client *opencode.Client
29-
State *config.State
30-
Provider *opencode.Provider
31-
Model *opencode.Model
32-
Session *opencode.Session
33-
Messages []opencode.MessageUnion
34-
Commands commands.CommandRegistry
24+
Info opencode.App
25+
Version string
26+
StatePath string
27+
Config *opencode.Config
28+
Client *opencode.Client
29+
State *config.State
30+
Provider *opencode.Provider
31+
Model *opencode.Model
32+
Session *opencode.Session
33+
Messages []opencode.MessageUnion
34+
Commands commands.CommandRegistry
35+
InitialModel *string
36+
InitialPrompt *string
3537
}
3638

3739
type SessionSelectedMsg = *opencode.Session
@@ -58,6 +60,8 @@ func New(
5860
version string,
5961
appInfo opencode.App,
6062
httpClient *opencode.Client,
63+
model *string,
64+
prompt *string,
6165
) (*App, error) {
6266
util.RootPath = appInfo.Path.Root
6367
util.CwdPath = appInfo.Path.Cwd
@@ -109,15 +113,17 @@ func New(
109113
slog.Debug("Loaded config", "config", configInfo)
110114

111115
app := &App{
112-
Info: appInfo,
113-
Version: version,
114-
StatePath: appStatePath,
115-
Config: configInfo,
116-
State: appState,
117-
Client: httpClient,
118-
Session: &opencode.Session{},
119-
Messages: []opencode.MessageUnion{},
120-
Commands: commands.LoadFromConfig(configInfo),
116+
Info: appInfo,
117+
Version: version,
118+
StatePath: appStatePath,
119+
Config: configInfo,
120+
State: appState,
121+
Client: httpClient,
122+
Session: &opencode.Session{},
123+
Messages: []opencode.MessageUnion{},
124+
Commands: commands.LoadFromConfig(configInfo),
125+
InitialModel: model,
126+
InitialPrompt: prompt,
121127
}
122128

123129
return app, nil
@@ -141,65 +147,89 @@ func (a *App) Key(commandName commands.CommandName) string {
141147
}
142148

143149
func (a *App) InitializeProvider() tea.Cmd {
144-
return func() tea.Msg {
145-
providersResponse, err := a.Client.Config.Providers(context.Background())
146-
if err != nil {
147-
slog.Error("Failed to list providers", "error", err)
148-
// TODO: notify user
149-
return nil
150+
providersResponse, err := a.Client.Config.Providers(context.Background())
151+
if err != nil {
152+
slog.Error("Failed to list providers", "error", err)
153+
// TODO: notify user
154+
return nil
155+
}
156+
providers := providersResponse.Providers
157+
var defaultProvider *opencode.Provider
158+
var defaultModel *opencode.Model
159+
160+
var anthropic *opencode.Provider
161+
for _, provider := range providers {
162+
if provider.ID == "anthropic" {
163+
anthropic = &provider
150164
}
151-
providers := providersResponse.Providers
152-
var defaultProvider *opencode.Provider
153-
var defaultModel *opencode.Model
165+
}
154166

155-
var anthropic *opencode.Provider
156-
for _, provider := range providers {
157-
if provider.ID == "anthropic" {
158-
anthropic = &provider
159-
}
160-
}
167+
// default to anthropic if available
168+
if anthropic != nil {
169+
defaultProvider = anthropic
170+
defaultModel = getDefaultModel(providersResponse, *anthropic)
171+
}
161172

162-
// default to anthropic if available
163-
if anthropic != nil {
164-
defaultProvider = anthropic
165-
defaultModel = getDefaultModel(providersResponse, *anthropic)
173+
for _, provider := range providers {
174+
if defaultProvider == nil || defaultModel == nil {
175+
defaultProvider = &provider
176+
defaultModel = getDefaultModel(providersResponse, provider)
166177
}
178+
providers = append(providers, provider)
179+
}
180+
if len(providers) == 0 {
181+
slog.Error("No providers configured")
182+
return nil
183+
}
167184

168-
for _, provider := range providers {
169-
if defaultProvider == nil || defaultModel == nil {
170-
defaultProvider = &provider
171-
defaultModel = getDefaultModel(providersResponse, provider)
185+
var currentProvider *opencode.Provider
186+
var currentModel *opencode.Model
187+
for _, provider := range providers {
188+
if provider.ID == a.State.Provider {
189+
currentProvider = &provider
190+
191+
for _, model := range provider.Models {
192+
if model.ID == a.State.Model {
193+
currentModel = &model
194+
}
172195
}
173-
providers = append(providers, provider)
174-
}
175-
if len(providers) == 0 {
176-
slog.Error("No providers configured")
177-
return nil
178196
}
197+
}
198+
if currentProvider == nil || currentModel == nil {
199+
currentProvider = defaultProvider
200+
currentModel = defaultModel
201+
}
179202

180-
var currentProvider *opencode.Provider
181-
var currentModel *opencode.Model
203+
var initialProvider *opencode.Provider
204+
var initialModel *opencode.Model
205+
if a.InitialModel != nil && *a.InitialModel != "" {
206+
splits := strings.Split(*a.InitialModel, "/")
182207
for _, provider := range providers {
183-
if provider.ID == a.State.Provider {
184-
currentProvider = &provider
185-
208+
if provider.ID == splits[0] {
209+
initialProvider = &provider
186210
for _, model := range provider.Models {
187-
if model.ID == a.State.Model {
188-
currentModel = &model
211+
if model.ID == splits[1] {
212+
initialModel = &model
189213
}
190214
}
191215
}
192216
}
193-
if currentProvider == nil || currentModel == nil {
194-
currentProvider = defaultProvider
195-
currentModel = defaultModel
196-
}
217+
}
197218

198-
return ModelSelectedMsg{
199-
Provider: *currentProvider,
200-
Model: *currentModel,
201-
}
219+
if initialProvider != nil && initialModel != nil {
220+
currentProvider = initialProvider
221+
currentModel = initialModel
222+
}
223+
224+
var cmds []tea.Cmd
225+
cmds = append(cmds, util.CmdHandler(ModelSelectedMsg{
226+
Provider: *currentProvider,
227+
Model: *currentModel,
228+
}))
229+
if a.InitialPrompt != nil && *a.InitialPrompt != "" {
230+
cmds = append(cmds, util.CmdHandler(SendMsg{Text: *a.InitialPrompt}))
202231
}
232+
return tea.Sequence(cmds...)
203233
}
204234

205235
func getDefaultModel(

packages/tui/internal/components/chat/editor.go

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (m *editorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
6464
return m, tea.Batch(cmds...)
6565
}
6666
case dialog.ThemeSelectedMsg:
67-
m.textarea = createTextArea(&m.textarea)
67+
m.textarea = m.resetTextareaStyles()
6868
m.spinner = createSpinner()
6969
return m, tea.Batch(m.spinner.Tick, m.textarea.Focus())
7070
case dialog.CompletionSelectedMsg:
@@ -306,13 +306,13 @@ func (m *editorComponent) getSubmitKeyText() string {
306306
return m.app.Commands[commands.InputSubmitCommand].Keys()[0]
307307
}
308308

309-
func createTextArea(existing *textarea.Model) textarea.Model {
309+
func (m *editorComponent) resetTextareaStyles() textarea.Model {
310310
t := theme.CurrentTheme()
311311
bgColor := t.BackgroundElement()
312312
textColor := t.Text()
313313
textMutedColor := t.TextMuted()
314314

315-
ta := textarea.New()
315+
ta := m.textarea
316316

317317
ta.Styles.Blurred.Base = styles.NewStyle().Foreground(textColor).Background(bgColor).Lipgloss()
318318
ta.Styles.Blurred.CursorLine = styles.NewStyle().Background(bgColor).Lipgloss()
@@ -337,17 +337,6 @@ func createTextArea(existing *textarea.Model) textarea.Model {
337337
Background(t.Secondary()).
338338
Lipgloss()
339339
ta.Styles.Cursor.Color = t.Primary()
340-
341-
ta.Prompt = " "
342-
ta.ShowLineNumbers = false
343-
ta.CharLimit = -1
344-
345-
if existing != nil {
346-
ta.SetValue(existing.Value())
347-
// ta.SetWidth(existing.Width())
348-
ta.SetHeight(existing.Height())
349-
}
350-
351340
return ta
352341
}
353342

@@ -367,12 +356,19 @@ func createSpinner() spinner.Model {
367356

368357
func NewEditorComponent(app *app.App) EditorComponent {
369358
s := createSpinner()
370-
ta := createTextArea(nil)
371359

372-
return &editorComponent{
360+
ta := textarea.New()
361+
ta.Prompt = " "
362+
ta.ShowLineNumbers = false
363+
ta.CharLimit = -1
364+
365+
m := &editorComponent{
373366
app: app,
374367
textarea: ta,
375368
spinner: s,
376369
interruptKeyInDebounce: false,
377370
}
371+
m.resetTextareaStyles()
372+
373+
return m
378374
}

0 commit comments

Comments
 (0)