charmbracelet / huh

Build terminal forms and prompts 🤷🏻‍♀️
MIT License

Get input as an integer or other type #297

Open intercepted16 opened 1 week ago

intercepted16 commented 1 week ago

Is your feature request related to a problem? Please describe. Yes. I am trying to get an integer from the user with NewInput, but I can't, because Value expects a pointer to a string (*string), not a pointer to an integer.

Describe the solution you'd like: An option to specify the value type as a generic type argument.

Describe alternatives you've considered: Providing a function that takes a conversion function as an argument, which converts the entered string to an integer.


shaunco commented 5 days ago

Agreed, it would be great to have custom Marshal/Unmarshal functions, plus a ValueAny that could be used instead of Value, so that conversion to and from a string works for virtually any type. That would also need a ValidateAny whose callback takes an interface{} (any), or it would be fine for validation to keep happening at the string level, before Unmarshal is called.

shaunco commented 4 days ago

I took the existing Text and Input types and made them generic with marshal/unmarshal functions. You can drop these in your own project somewhere rather than waiting for something like this in huh.

InputAny[T any] (a generic version of Input)

### field_inputany.go

```go
package shell

import (
	"fmt"
	"strings"

	"github.com/charmbracelet/bubbles/key"
	"github.com/charmbracelet/bubbles/textinput"
	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/huh"
	"github.com/charmbracelet/huh/accessibility"
	"github.com/charmbracelet/lipgloss"
)

// InputAny[T] is a generic form input field.
type InputAny[T any] struct {
	value    *T
	strValue *string
	key      string

	// customization
	title       string
	description string
	inline      bool

	// marshaling T<->string
	marshal   func(T) string
	unmarshal func(string) (T, error)

	// error handling
	validate func(T) error
	err      error

	// model
	textinput textinput.Model

	// state
	focused bool

	// options
	width      int
	height     int
	accessible bool
	theme      *huh.Theme
	keymap     huh.InputKeyMap
}

// NewInputAny returns a new input field.
func NewInputAny[T any](marshal func(T) string, unmarshal func(string) (T, error)) *InputAny[T] {
	input := textinput.New()

	i := &InputAny[T]{
		value:     new(T),
		strValue:  new(string),
		textinput: input,
		validate:  func(T) error { return nil },
		marshal:   marshal,
		unmarshal: unmarshal,
	}

	return i
}

// Value sets the value of the input field.
func (i *InputAny[T]) Value(value *T) *InputAny[T] {
	i.value = value
	strValue := i.marshal(*value)
	i.strValue = &strValue
	i.textinput.SetValue(*i.strValue)
	return i
}

// Key sets the key of the input field.
func (i *InputAny[T]) Key(key string) *InputAny[T] {
	i.key = key
	return i
}

// Title sets the title of the input field.
func (i *InputAny[T]) Title(title string) *InputAny[T] {
	i.title = title
	return i
}

// Description sets the description of the input field.
func (i *InputAny[T]) Description(description string) *InputAny[T] {
	i.description = description
	return i
}

// Prompt sets the prompt of the input field.
func (i *InputAny[T]) Prompt(prompt string) *InputAny[T] {
	i.textinput.Prompt = prompt
	return i
}

// CharLimit sets the character limit of the input field.
func (i *InputAny[T]) CharLimit(charlimit int) *InputAny[T] {
	i.textinput.CharLimit = charlimit
	return i
}

// Suggestions sets the suggestions to display for autocomplete in the input
// field.
func (i *InputAny[T]) Suggestions(suggestions []string) *InputAny[T] {
	i.textinput.ShowSuggestions = len(suggestions) > 0
	i.textinput.KeyMap.AcceptSuggestion.SetEnabled(len(suggestions) > 0)
	i.textinput.SetSuggestions(suggestions)
	return i
}

// EchoMode sets the input behavior of the text Input field.
type EchoMode textinput.EchoMode

const (
	// EchoModeNormal displays text as is.
	// This is the default behavior.
	EchoModeNormal EchoMode = EchoMode(textinput.EchoNormal)

	// EchoModePassword displays the EchoCharacter mask instead of actual characters.
	// This is commonly used for password fields.
	EchoModePassword EchoMode = EchoMode(textinput.EchoPassword)

	// EchoModeNone displays nothing as characters are entered.
	// This is commonly seen for password fields on the command line.
	EchoModeNone EchoMode = EchoMode(textinput.EchoNone)
)

// EchoMode sets the echo mode of the input.
func (i *InputAny[T]) EchoMode(mode EchoMode) *InputAny[T] {
	i.textinput.EchoMode = textinput.EchoMode(mode)
	return i
}

// Password sets whether or not to hide the input while the user is typing.
//
// Deprecated: use EchoMode(EchoModePassword) instead.
func (i *InputAny[T]) Password(password bool) *InputAny[T] {
	if password {
		i.textinput.EchoMode = textinput.EchoPassword
	} else {
		i.textinput.EchoMode = textinput.EchoNormal
	}
	return i
}

// Placeholder sets the placeholder of the text input.
func (i *InputAny[T]) Placeholder(str string) *InputAny[T] {
	i.textinput.Placeholder = str
	return i
}

// Inline sets whether the title and input should be on the same line.
func (i *InputAny[T]) Inline(inline bool) *InputAny[T] {
	i.inline = inline
	return i
}

// Validate sets the validation function of the input field.
func (i *InputAny[T]) Validate(validate func(T) error) *InputAny[T] {
	i.validate = validate
	return i
}

// Error returns the error of the input field.
func (i *InputAny[T]) Error() error {
	return i.err
}

// Skip returns whether the input should be skipped or should be blocking.
func (*InputAny[T]) Skip() bool {
	return false
}

// Zoom returns whether the input should be zoomed.
func (*InputAny[T]) Zoom() bool {
	return false
}

// Focus focuses the input field.
func (i *InputAny[T]) Focus() tea.Cmd {
	i.focused = true
	return i.textinput.Focus()
}

// Blur blurs the input field and commits the typed value if it is valid.
func (i *InputAny[T]) Blur() tea.Cmd {
	i.focused = false
	*i.strValue = i.textinput.Value()
	i.textinput.Blur()
	i.saveValue()
	return nil
}

// KeyBinds returns the help message for the input field.
func (i *InputAny[T]) KeyBinds() []key.Binding {
	if i.textinput.ShowSuggestions {
		return []key.Binding{i.keymap.AcceptSuggestion, i.keymap.Prev, i.keymap.Submit, i.keymap.Next}
	}
	return []key.Binding{i.keymap.Prev, i.keymap.Submit, i.keymap.Next}
}

// Init initializes the input field.
func (i *InputAny[T]) Init() tea.Cmd {
	i.textinput.Blur()
	return nil
}

// Update updates the input field.
func (i *InputAny[T]) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	var cmds []tea.Cmd
	var cmd tea.Cmd

	i.textinput, cmd = i.textinput.Update(msg)
	cmds = append(cmds, cmd)
	*i.strValue = i.textinput.Value()

	switch msg := msg.(type) {
	case tea.KeyMsg:
		i.err = nil

		switch {
		case key.Matches(msg, i.keymap.Prev):
			i.saveValue()
			if i.err != nil {
				return i, nil
			}
			cmds = append(cmds, huh.PrevField)
		case key.Matches(msg, i.keymap.Next, i.keymap.Submit):
			i.saveValue()
			if i.err != nil {
				return i, nil
			}
			cmds = append(cmds, huh.NextField)
		}
	}

	return i, tea.Batch(cmds...)
}

// saveValue parses and validates the current text and, only if both
// succeed, writes the result through the bound pointer.
func (i *InputAny[T]) saveValue() {
	value, err := i.unmarshal(i.textinput.Value())
	if err == nil {
		err = i.validate(value)
	}
	if err == nil {
		*i.value = value
	}
	i.err = err
}

func (i *InputAny[T]) activeStyles() *huh.FieldStyles {
	theme := i.theme
	if theme == nil {
		theme = huh.ThemeCharm()
	}
	if i.focused {
		return &theme.Focused
	}
	return &theme.Blurred
}

// View renders the input field.
func (i *InputAny[T]) View() string {
	styles := i.activeStyles()

	// NB: since the method is on a pointer receiver these are being mutated.
	// Because this runs on every render this shouldn't matter in practice,
	// however.
	i.textinput.PlaceholderStyle = styles.TextInput.Placeholder
	i.textinput.PromptStyle = styles.TextInput.Prompt
	i.textinput.Cursor.Style = styles.TextInput.Cursor
	i.textinput.TextStyle = styles.TextInput.Text

	var sb strings.Builder
	if i.title != "" {
		sb.WriteString(styles.Title.Render(i.title))
		if !i.inline {
			sb.WriteString("\n")
		}
	}
	if i.description != "" {
		sb.WriteString(styles.Description.Render(i.description))
		if !i.inline {
			sb.WriteString("\n")
		}
	}
	sb.WriteString(i.textinput.View())

	return styles.Base.Render(sb.String())
}

// Run runs the input field, in accessible mode if enabled.
func (i *InputAny[T]) Run() error {
	if i.accessible {
		return i.runAccessible()
	}
	return i.run()
}

// run runs the input field.
func (i *InputAny[T]) run() error {
	return huh.Run(i)
}

// runAccessible runs the input field in accessible mode.
func (i *InputAny[T]) runAccessible() error {
	styles := i.activeStyles()
	fmt.Println(styles.Title.Render(i.title))
	fmt.Println()
	*i.strValue = accessibility.PromptString("Input: ", func(input string) error {
		value, err := i.unmarshal(input)
		if err != nil {
			return err
		}
		return i.validate(value)
	})
	// PromptString only returns input the validator accepted, so this
	// unmarshal is expected to succeed; store the typed result.
	*i.value, i.err = i.unmarshal(*i.strValue)
	fmt.Println(styles.SelectedOption.Render("Input: " + *i.strValue + "\n"))
	return nil
}

// WithKeyMap sets the keymap on an input field.
func (i *InputAny[T]) WithKeyMap(k *huh.KeyMap) huh.Field {
	i.keymap = k.Input
	i.textinput.KeyMap.AcceptSuggestion = i.keymap.AcceptSuggestion
	return i
}

// WithAccessible sets the accessible mode of the input field.
func (i *InputAny[T]) WithAccessible(accessible bool) huh.Field {
	i.accessible = accessible
	return i
}

// WithTheme sets the theme of the input field.
func (i *InputAny[T]) WithTheme(theme *huh.Theme) huh.Field {
	if i.theme != nil {
		return i
	}
	i.theme = theme
	return i
}

// WithWidth sets the width of the input field.
func (i *InputAny[T]) WithWidth(width int) huh.Field {
	styles := i.activeStyles()
	i.width = width
	frameSize := styles.Base.GetHorizontalFrameSize()
	promptWidth := lipgloss.Width(i.textinput.PromptStyle.Render(i.textinput.Prompt))
	titleWidth := lipgloss.Width(styles.Title.Render(i.title))
	descriptionWidth := lipgloss.Width(styles.Description.Render(i.description))
	i.textinput.Width = width - frameSize - promptWidth - 1
	if i.inline {
		i.textinput.Width -= titleWidth
		i.textinput.Width -= descriptionWidth
	}
	return i
}

// WithHeight sets the height of the input field.
func (i *InputAny[T]) WithHeight(height int) huh.Field {
	i.height = height
	return i
}

// WithPosition sets the position of the input field.
func (i *InputAny[T]) WithPosition(p huh.FieldPosition) huh.Field {
	i.keymap.Prev.SetEnabled(!p.IsFirst())
	i.keymap.Next.SetEnabled(!p.IsLast())
	i.keymap.Submit.SetEnabled(p.IsLast())
	return i
}

// GetKey returns the key of the field.
func (i *InputAny[T]) GetKey() string {
	return i.key
}

// GetValue returns the value of the field.
func (i *InputAny[T]) GetValue() any {
	return *i.value
}
```

TextAny[T any] (a generic version of Text)

### field_textany.go

```go
package shell

import (
	"fmt"
	"os"
	"os/exec"
	"strings"

	"github.com/charmbracelet/bubbles/key"
	"github.com/charmbracelet/bubbles/textarea"
	tea "github.com/charmbracelet/bubbletea"
	"github.com/charmbracelet/huh"
	"github.com/charmbracelet/huh/accessibility"
	"github.com/charmbracelet/lipgloss"
)

// TextAny[T] is a generic form text field. It allows for a multi-line string input.
type TextAny[T any] struct {
	value    *T
	strValue *string
	key      string

	// marshaling T<->string
	marshal   func(T) string
	unmarshal func(string) (T, error)

	// error handling
	validate func(T) error
	err      error

	// model
	textarea textarea.Model

	// customization
	title           string
	description     string
	editorCmd       string
	editorArgs      []string
	editorExtension string

	// state
	focused bool

	// form options
	width      int
	accessible bool
	theme      *huh.Theme
	keymap     huh.TextKeyMap
}

// NewTextAny returns a new text field.
func NewTextAny[T any](marshal func(T) string, unmarshal func(string) (T, error)) *TextAny[T] {
	text := textarea.New()
	text.ShowLineNumbers = false
	text.Prompt = ""
	text.FocusedStyle.CursorLine = lipgloss.NewStyle()

	editorCmd, editorArgs := getEditor()

	t := &TextAny[T]{
		value:           new(T),
		strValue:        new(string),
		textarea:        text,
		validate:        func(T) error { return nil },
		marshal:         marshal,
		unmarshal:       unmarshal,
		editorCmd:       editorCmd,
		editorArgs:      editorArgs,
		editorExtension: "md",
	}

	return t
}

// Value sets the value of the text field.
func (t *TextAny[T]) Value(value *T) *TextAny[T] {
	t.value = value
	strValue := t.marshal(*value)
	t.strValue = &strValue
	t.textarea.SetValue(*t.strValue)
	return t
}

// Key sets the key of the text field.
func (t *TextAny[T]) Key(key string) *TextAny[T] {
	t.key = key
	return t
}

// Title sets the title of the text field.
func (t *TextAny[T]) Title(title string) *TextAny[T] {
	t.title = title
	return t
}

// Lines sets the number of lines to show of the text field.
func (t *TextAny[T]) Lines(lines int) *TextAny[T] {
	t.textarea.SetHeight(lines)
	return t
}

// Description sets the description of the text field.
func (t *TextAny[T]) Description(description string) *TextAny[T] {
	t.description = description
	return t
}

// CharLimit sets the character limit of the text field.
func (t *TextAny[T]) CharLimit(charlimit int) *TextAny[T] {
	t.textarea.CharLimit = charlimit
	return t
}

// ShowLineNumbers sets whether or not to show line numbers.
func (t *TextAny[T]) ShowLineNumbers(show bool) *TextAny[T] {
	t.textarea.ShowLineNumbers = show
	return t
}

// Placeholder sets the placeholder of the text field.
func (t *TextAny[T]) Placeholder(str string) *TextAny[T] {
	t.textarea.Placeholder = str
	return t
}

// Validate sets the validation function of the text field.
func (t *TextAny[T]) Validate(validate func(T) error) *TextAny[T] {
	t.validate = validate
	return t
}

const defaultEditor = "nano"

// getEditor returns the editor command and arguments.
func getEditor() (string, []string) {
	editor := strings.Fields(os.Getenv("EDITOR"))
	if len(editor) > 0 {
		return editor[0], editor[1:]
	}
	return defaultEditor, nil
}

// Editor specifies which editor to use.
//
// The first argument provided is used as the editor command (vim, nvim, nano, etc...)
// The following (optional) arguments provided are passed as arguments to the editor command.
func (t *TextAny[T]) Editor(editor ...string) *TextAny[T] {
	if len(editor) > 0 {
		t.editorCmd = editor[0]
	}
	if len(editor) > 1 {
		t.editorArgs = editor[1:]
	}
	return t
}

// EditorExtension sets the file extension of the temporary file opened in
// the editor (default "md").
func (t *TextAny[T]) EditorExtension(extension string) *TextAny[T] {
	t.editorExtension = extension
	return t
}

// Error returns the error of the text field.
func (t *TextAny[T]) Error() error {
	return t.err
}

// Skip returns whether the textarea should be skipped or should be blocking.
func (*TextAny[T]) Skip() bool {
	return false
}

// Zoom returns whether the text field should be zoomed.
func (*TextAny[T]) Zoom() bool {
	return false
}

// Focus focuses the text field.
func (t *TextAny[T]) Focus() tea.Cmd {
	t.focused = true
	return t.textarea.Focus()
}

// Blur blurs the text field and commits the typed value if it is valid.
func (t *TextAny[T]) Blur() tea.Cmd {
	t.focused = false
	*t.strValue = t.textarea.Value()
	t.textarea.Blur()
	t.saveValue()
	return nil
}

// KeyBinds returns the help message for the text field.
func (t *TextAny[T]) KeyBinds() []key.Binding {
	return []key.Binding{t.keymap.NewLine, t.keymap.Editor, t.keymap.Prev, t.keymap.Submit, t.keymap.Next}
}

type updateValueMsg []byte

// Init initializes the text field.
func (t *TextAny[T]) Init() tea.Cmd {
	t.textarea.Blur()
	return nil
}

// Update updates the text field.
func (t *TextAny[T]) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	var cmds []tea.Cmd
	var cmd tea.Cmd

	t.textarea, cmd = t.textarea.Update(msg)
	cmds = append(cmds, cmd)
	*t.strValue = t.textarea.Value()

	switch msg := msg.(type) {
	case updateValueMsg:
		t.textarea.SetValue(string(msg))
		t.textarea, cmd = t.textarea.Update(msg)
		cmds = append(cmds, cmd)
		*t.strValue = t.textarea.Value()

	case tea.KeyMsg:
		t.err = nil

		switch {
		case key.Matches(msg, t.keymap.Editor):
			ext := strings.TrimPrefix(t.editorExtension, ".")
			tmpFile, _ := os.CreateTemp(os.TempDir(), "*."+ext)
			cmd := exec.Command(t.editorCmd, append(t.editorArgs, tmpFile.Name())...)
			_ = os.WriteFile(tmpFile.Name(), []byte(t.textarea.Value()), 0600)
			cmds = append(cmds, tea.ExecProcess(cmd, func(error) tea.Msg {
				content, _ := os.ReadFile(tmpFile.Name())
				_ = os.Remove(tmpFile.Name())
				return updateValueMsg(content)
			}))
		case key.Matches(msg, t.keymap.Next, t.keymap.Submit):
			t.saveValue()
			if t.err != nil {
				return t, nil
			}
			cmds = append(cmds, huh.NextField)
		case key.Matches(msg, t.keymap.Prev):
			t.saveValue()
			if t.err != nil {
				return t, nil
			}
			cmds = append(cmds, huh.PrevField)
		}
	}

	return t, tea.Batch(cmds...)
}

// saveValue parses and validates the current text and, only if both
// succeed, writes the result through the bound pointer.
func (t *TextAny[T]) saveValue() {
	value, err := t.unmarshal(t.textarea.Value())
	if err == nil {
		err = t.validate(value)
	}
	if err == nil {
		*t.value = value
	}
	t.err = err
}

func (t *TextAny[T]) activeStyles() *huh.FieldStyles {
	theme := t.theme
	if theme == nil {
		theme = huh.ThemeCharm()
	}
	if t.focused {
		return &theme.Focused
	}
	return &theme.Blurred
}

func (t *TextAny[T]) activeTextAreaStyles() *textarea.Style {
	if t.theme == nil {
		return &t.textarea.BlurredStyle
	}
	if t.focused {
		return &t.textarea.FocusedStyle
	}
	return &t.textarea.BlurredStyle
}

// View renders the text field.
func (t *TextAny[T]) View() string {
	var styles = t.activeStyles()
	var textareaStyles = t.activeTextAreaStyles()

	// NB: since the method is on a pointer receiver these are being mutated.
	// Because this runs on every render this shouldn't matter in practice,
	// however.
	textareaStyles.Placeholder = styles.TextInput.Placeholder
	textareaStyles.Text = styles.TextInput.Text
	textareaStyles.Prompt = styles.TextInput.Prompt
	textareaStyles.CursorLine = styles.TextInput.Text
	t.textarea.Cursor.Style = styles.TextInput.Cursor

	var sb strings.Builder
	if t.title != "" {
		sb.WriteString(styles.Title.Render(t.title))
		if t.err != nil {
			sb.WriteString(styles.ErrorIndicator.String())
		}
		sb.WriteString("\n")
	}
	if t.description != "" {
		sb.WriteString(styles.Description.Render(t.description))
		sb.WriteString("\n")
	}
	sb.WriteString(t.textarea.View())

	return styles.Base.Render(sb.String())
}

// Run runs the text field.
func (t *TextAny[T]) Run() error {
	if t.accessible {
		return t.runAccessible()
	}
	return huh.Run(t)
}

// runAccessible runs an accessible text field.
func (t *TextAny[T]) runAccessible() error {
	styles := t.activeStyles()
	fmt.Println(styles.Title.Render(t.title))
	fmt.Println()
	*t.strValue = accessibility.PromptString("Input: ", func(input string) error {
		value, err := t.unmarshal(input)
		if err != nil {
			return err
		}
		if err := t.validate(value); err != nil {
			return err
		}
		// A CharLimit of 0 means "no limit", so only enforce positive limits.
		if t.textarea.CharLimit > 0 && len(input) > t.textarea.CharLimit {
			return fmt.Errorf("Input cannot exceed %d characters", t.textarea.CharLimit)
		}
		return nil
	})
	// PromptString only returns input the validator accepted, so this
	// unmarshal is expected to succeed; store the typed result.
	*t.value, t.err = t.unmarshal(*t.strValue)
	fmt.Println()
	return nil
}

// WithTheme sets the theme on a text field.
func (t *TextAny[T]) WithTheme(theme *huh.Theme) huh.Field {
	if t.theme != nil {
		return t
	}
	t.theme = theme
	return t
}

// WithKeyMap sets the keymap on a text field.
func (t *TextAny[T]) WithKeyMap(k *huh.KeyMap) huh.Field {
	t.keymap = k.Text
	t.textarea.KeyMap.InsertNewline.SetKeys(t.keymap.NewLine.Keys()...)
	return t
}

// WithAccessible sets the accessible mode of the text field.
func (t *TextAny[T]) WithAccessible(accessible bool) huh.Field {
	t.accessible = accessible
	return t
}

// WithWidth sets the width of the text field.
func (t *TextAny[T]) WithWidth(width int) huh.Field {
	t.width = width
	t.textarea.SetWidth(width - t.activeStyles().Base.GetHorizontalFrameSize())
	return t
}

// WithHeight sets the height of the text field.
func (t *TextAny[T]) WithHeight(height int) huh.Field {
	adjust := 0
	if t.title != "" {
		adjust++
	}
	if t.description != "" {
		adjust++
	}
	t.textarea.SetHeight(height - t.activeStyles().Base.GetVerticalFrameSize() - adjust)
	return t
}

// WithPosition sets the position information of the text field.
func (t *TextAny[T]) WithPosition(p huh.FieldPosition) huh.Field {
	t.keymap.Prev.SetEnabled(!p.IsFirst())
	t.keymap.Next.SetEnabled(!p.IsLast())
	t.keymap.Submit.SetEnabled(p.IsLast())
	return t
}

// GetKey returns the key of the field.
func (t *TextAny[T]) GetKey() string {
	return t.key
}

// GetValue returns the value of the field.
func (t *TextAny[T]) GetValue() any {
	return *t.value
}
```

Then, if I want an int field to collect a port number, I can do:

```go
func marshalInt(input int) string {
    return strconv.FormatInt(int64(input), 10)
}

func unmarshalInt(input string) (int, error) {
    return strconv.Atoi(input)
}

func validatePort(port int) error {
    if port < 0 || port > 65535 {
        return fmt.Errorf("Invalid port")
    }
    return nil
}

// ...

NewInputAny(marshalInt, unmarshalInt).
    Title("Port").
    Value(&portNumber).
    Validate(validatePort)
```
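If the port can be stored as a uint16 instead of an int, strconv.ParseUint with a bit size of 16 rejects signs, non-digits, and values above 65535 during unmarshaling, leaving validatePort only the policy checks (for example, rejecting 0). A sketch under that assumption; marshalPort and unmarshalPort are hypothetical names:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// marshalPort renders a port number as decimal text.
func marshalPort(p uint16) string {
	return strconv.FormatUint(uint64(p), 10)
}

// unmarshalPort parses decimal text into a uint16. ParseUint with bitSize 16
// returns an error for values outside 0..65535, and a sign prefix is rejected.
func unmarshalPort(s string) (uint16, error) {
	n, err := strconv.ParseUint(strings.TrimSpace(s), 10, 16)
	if err != nil {
		return 0, err
	}
	return uint16(n), nil
}

func main() {
	for _, in := range []string{"443", " 8080 ", "65536", "-1", "http"} {
		p, err := unmarshalPort(in)
		fmt.Printf("%q -> %d, failed=%v\n", in, p, err != nil)
	}
	fmt.Println(marshalPort(22)) // 22
}
```

It would be wired up the same way as above, with NewInputAny(marshalPort, unmarshalPort) bound to a uint16 variable.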

If I want to collect a single IPv4 address into a netip.Addr, I can do:

```go
func marshalAddr(input netip.Addr) string {
    return input.String()
}

func unmarshalAddr(input string) (netip.Addr, error) {
    return netip.ParseAddr(input)
}

func validateAddrv4(addr netip.Addr) error {
    if !addr.Is4() {
        return fmt.Errorf("Address must be IPv4")
    }

    return nil
}

// ...

NewInputAny(marshalAddr, unmarshalAddr).
    Title("IP").
    Value(&m.balenaConfig.Proxy.IP).
    Validate(validateAddrv4),
```
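One subtlety: netip.ParseAddr also accepts IPv4-mapped IPv6 text such as ::ffff:10.0.0.1, for which Is4 reports false. If such input should count as IPv4, one option is to Unmap in the unmarshal step. A sketch; unmarshalAddrUnmapped is a hypothetical name:

```go
package main

import (
	"fmt"
	"net/netip"
	"strings"
)

// unmarshalAddrUnmapped parses an IP address and converts IPv4-mapped IPv6
// addresses (::ffff:a.b.c.d) to plain IPv4, so an Is4 check accepts them.
func unmarshalAddrUnmapped(s string) (netip.Addr, error) {
	addr, err := netip.ParseAddr(strings.TrimSpace(s))
	if err != nil {
		return netip.Addr{}, err
	}
	return addr.Unmap(), nil
}

func main() {
	mapped := netip.MustParseAddr("::ffff:10.0.0.1")
	fmt.Println(mapped.Is4()) // false

	addr, err := unmarshalAddrUnmapped("::ffff:10.0.0.1")
	fmt.Println(addr, addr.Is4(), err) // 10.0.0.1 true <nil>

	v6, _ := unmarshalAddrUnmapped("2001:db8::1")
	fmt.Println(v6.Is4()) // false
}
```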

If I want to collect a set of IP prefixes, one per line, into a []netip.Prefix and ensure they are valid and don't overlap, I can do:

```go
func marshalIPPrefixArray(input []netip.Prefix) string {
    // Convert to a string with one CIDR per line
    var s string
    for _, prefix := range input {
        s += prefix.String() + "\n"
    }

    return s
}

var netipParseCleanup = regexp.MustCompile(`netip.ParsePrefix\(\"(.*)\"\): (.*)`)

func unmarshalIPPrefixArray(input string) ([]netip.Prefix, error) {
    // Split the string into lines
    lines := strings.Split(input, "\n")

    // Parse each line as a CIDR
    var prefixes []netip.Prefix
    for _, line := range lines {
        // Tolerate stray spaces and "\r" left behind by an external editor
        line = strings.TrimSpace(line)
        if line == "" {
            continue
        }
        prefix, err := netip.ParsePrefix(line)
        if err != nil {
            // Clean up the error message. errors.New, not fmt.Errorf, because
            // errStr is not a format string and may contain '%'.
            errStr := netipParseCleanup.ReplaceAllString(err.Error(), "'$1': $2")
            return nil, errors.New(errStr)
        }
        prefixes = append(prefixes, prefix)
    }

    return prefixes, nil
}

func validateIPPrefixArray(prefixes []netip.Prefix) error {
    // Check for empty array
    if len(prefixes) == 0 {
        return fmt.Errorf("IP Prefix array cannot be empty")
    }

    // Check for invalid prefixes
    for _, prefix := range prefixes {
        if !prefix.IsValid() {
            return fmt.Errorf("Invalid IP Prefix: %s", prefix)
        }
    }

    // Check for overlapping prefixes
    for i, prefix := range prefixes {
        for j, otherPrefix := range prefixes {
            if i != j && prefix.Overlaps(otherPrefix) {
                return fmt.Errorf("Overlapping IP Prefixes: %s and %s", prefix, otherPrefix)
            }
        }
    }

    return nil
}

// ...

NewTextAny(marshalIPPrefixArray, unmarshalIPPrefixArray).
    Title("CIDR(s)").
    Description("List of IPs in CIDR (1.2.3.4/24) format, one per line").
    Value(&m.balenaConfig.Ethernet[selectedInterface].IPv4.Addresses).
    Validate(validateIPPrefixArray),
```

intercepted16 commented 4 days ago

@shaunco Thanks a lot for that temporary solution!