// Package ssh_config provides tools for manipulating SSH config files.
//
// Importantly, this parser attempts to preserve comments in a given file, so
// you can manipulate a `ssh_config` file from a program, if your heart desires.
//
// The Get() and GetStrict() functions will attempt to read values from
// $HOME/.ssh/config, falling back to /etc/ssh/ssh_config. The first argument is
// the host name to match on ("example.com"), and the second argument is the key
// you want to retrieve ("Port"). The keywords are case insensitive.
//
// 		port := ssh_config.Get("myhost", "Port")
//
// You can also manipulate an SSH config file and then print it or write it back
// to disk.
//
//	f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config"))
//	cfg, _ := ssh_config.Decode(f)
//	for _, host := range cfg.Hosts {
//		fmt.Println("patterns:", host.Patterns)
//		for _, node := range host.Nodes {
//			fmt.Println(node.String())
//		}
//	}
//
//	// Print the cfg; the same text can be written back to disk:
//	fmt.Println(cfg.String())
//
// BUG: the Match directive is currently unsupported; parsing a config with
// a Match directive will trigger an error.
package ssh_config

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	osuser "os/user"
	"path/filepath"
	"regexp"
	"runtime"
	"strings"
	"sync"
)

// version is this library's version string.
const version = "1.0"

// Blank reference so version is never reported as unused — presumably for
// linters such as staticcheck U1000; confirm before removing.
var _ = version

// configFinder returns the path of an SSH config file to load. UserSettings
// holds one for the user file and one for the system file.
type configFinder func() string

// UserSettings checks ~/.ssh and /etc/ssh for configuration files. The config
// files are parsed and cached the first time Get, GetAll, GetStrict or
// GetAllStrict is called.
type UserSettings struct {
	// IgnoreErrors, when true, makes GetStrict and GetAllStrict skip the error
	// recorded while loading the config files and look values up in whatever
	// was loaded.
	IgnoreErrors       bool
	// systemConfig is the parsed system config file, or nil if it was missing
	// or not loaded.
	systemConfig       *Config
	// systemConfigFinder locates the system config; nil means use the
	// package default (systemConfigFinder).
	systemConfigFinder configFinder
	// userConfig is the parsed user config file, or nil if it was missing or
	// not loaded.
	userConfig         *Config
	// userConfigFinder locates the user config; nil means use the package
	// default (userConfigFinder).
	userConfigFinder   configFinder
	// loadConfigs ensures the config files are read at most once.
	loadConfigs        sync.Once
	// onceErr is the error (if any) recorded while loading the config files.
	onceErr            error
}

// homedir returns the current user's home directory. It prefers the value
// reported by os/user and falls back to $HOME when that lookup fails or
// reports an empty directory.
func homedir() string {
	if u, err := osuser.Current(); err == nil && u.HomeDir != "" {
		return u.HomeDir
	}
	return os.Getenv("HOME")
}

// userConfigFinder returns the path of the per-user SSH config file,
// <home>/.ssh/config, where <home> comes from homedir.
func userConfigFinder() string {
	home := homedir()
	return filepath.Join(home, ".ssh", "config")
}

// DefaultUserSettings is the UserSettings behind the package-level Get,
// GetAll, GetStrict and GetAllStrict functions. It reads the user's
// ~/.ssh/config and /etc/ssh/ssh_config. IgnoreErrors is left at its zero
// value (false), so parse errors are reported rather than swallowed.
var DefaultUserSettings = &UserSettings{
	userConfigFinder:   userConfigFinder,
	systemConfigFinder: systemConfigFinder,
}

// systemConfigFinder returns the path of the system-wide SSH config file.
func systemConfigFinder() string {
	const sshDir = "/etc/ssh"
	return filepath.Join(sshDir, "ssh_config")
}

// findVal looks key up for alias in c and validates the result.
// A nil Config, a lookup error, or an empty value all yield "" (with the
// lookup error, if any). A non-empty value that fails validate is reported as
// an error.
func findVal(c *Config, alias, key string) (string, error) {
	if c == nil {
		return "", nil
	}
	val, err := c.Get(alias, key)
	if err != nil {
		return "", err
	}
	if val == "" {
		return "", nil
	}
	if err = validate(key, val); err != nil {
		return "", err
	}
	return val, nil
}

// findAll returns every value for key that applies to alias in c.
// A nil Config yields (nil, nil).
func findAll(c *Config, alias, key string) ([]string, error) {
	if c != nil {
		return c.GetAll(alias, key)
	}
	return nil, nil
}

// Get finds the first value for key within a declaration that matches the
// alias, falling back to the key's default value when nothing matches. Get
// returns the empty string if no value (and no default) was found, or if
// IgnoreErrors is false and we could not parse the configuration file. Use
// GetStrict to disambiguate the latter cases.
//
// The match for key is case insensitive.
//
// Get is a wrapper around DefaultUserSettings.Get.
func Get(alias, key string) string {
	settings := DefaultUserSettings
	return settings.Get(alias, key)
}

// GetAll retrieves zero or more directives for key for the given alias. If no
// directive matches, GetAll returns a one-element slice holding the key's
// default value, or an empty (non-nil) slice when the key has no default.
// GetAll returns nil if IgnoreErrors is false and we could not parse the
// configuration file. Use GetAllStrict to disambiguate the latter cases.
//
// In most cases you want to use Get or GetStrict, which returns a single value.
// However, a subset of ssh configuration values (IdentityFile, for example)
// allow you to specify multiple directives.
//
// The match for key is case insensitive.
//
// GetAll is a wrapper around DefaultUserSettings.GetAll.
func GetAll(alias, key string) []string {
	return DefaultUserSettings.GetAll(alias, key)
}

// GetStrict finds the first value for key within a declaration that matches the
// alias. If key has a default value and no matching configuration is found, the
// default will be returned. For more information on default values and the way
// patterns are matched, see the manpage for ssh_config.
//
// The returned error is non-nil when the user's or the system configuration
// file could not be parsed (and IgnoreErrors is false), or when a found value
// fails validation.
//
// GetStrict is a wrapper around DefaultUserSettings.GetStrict.
func GetStrict(alias, key string) (string, error) {
	settings := DefaultUserSettings
	return settings.GetStrict(alias, key)
}

// GetAllStrict retrieves zero or more directives for key for the given alias.
//
// In most cases you want to use Get or GetStrict, which returns a single value.
// However, a subset of ssh configuration values (IdentityFile, for example)
// allow you to specify multiple directives.
//
// The returned error will be non-nil if a user's configuration file or the
// system configuration file could not be parsed and IgnoreErrors is false.
//
// GetAllStrict is a wrapper around DefaultUserSettings.GetAllStrict.
func GetAllStrict(alias, key string) ([]string, error) {
	settings := DefaultUserSettings
	return settings.GetAllStrict(alias, key)
}

// Get finds the first value for key within a declaration that matches the
// alias, falling back to the key's default. Get returns the empty string if
// nothing (and no default) was found, or if GetStrict reports any error — for
// example when IgnoreErrors is false and we could not parse the configuration
// file. Use GetStrict to disambiguate the latter cases.
//
// The match for key is case insensitive.
func (u *UserSettings) Get(alias, key string) string {
	if val, err := u.GetStrict(alias, key); err == nil {
		return val
	}
	return ""
}

// GetAll retrieves zero or more directives for key for the given alias. If no
// directive matches, GetAll returns a one-element slice holding the key's
// default value, or an empty (non-nil) slice when the key has no default.
// GetAll returns nil if GetAllStrict reports an error — for example when
// IgnoreErrors is false and we could not parse the configuration file. Use
// GetAllStrict to disambiguate the latter cases.
//
// The match for key is case insensitive.
func (u *UserSettings) GetAll(alias, key string) []string {
	// The error is deliberately dropped: GetAllStrict returns a nil slice
	// alongside its errors, which is exactly what this lenient API promises.
	val, _ := u.GetAllStrict(alias, key)
	return val
}

// GetStrict finds the first value for key within a declaration that matches the
// alias. The user's config file is consulted first and the system config file
// second. If neither yields a value, the key's default (possibly "") is
// returned. For more information on default values and the way patterns are
// matched, see the manpage for ssh_config.
//
// The returned error is non-nil when a configuration file could not be parsed
// and u.IgnoreErrors is false, or when a lookup fails or a found value does not
// pass validation.
func (u *UserSettings) GetStrict(alias, key string) (string, error) {
	u.doLoadConfigs()
	if u.onceErr != nil && !u.IgnoreErrors {
		return "", u.onceErr
	}
	// User settings shadow system settings: the first source that produces a
	// value (or an error) wins.
	for _, cfg := range []*Config{u.userConfig, u.systemConfig} {
		val, err := findVal(cfg, alias, key)
		if err != nil || val != "" {
			return val, err
		}
	}
	return Default(key), nil
}

// GetAllStrict retrieves zero or more directives for key for the given alias.
// The user's config file is consulted first; the system config file is only
// consulted if the user file produced no values. If neither matches, the
// result is a one-element slice holding the key's default, or an empty
// (non-nil) slice when there is no default. For more information on default
// values and the way patterns are matched, see the manpage for ssh_config.
//
// The returned error will be non-nil if a configuration file could not be
// parsed and u.IgnoreErrors is false, or if a lookup fails.
func (u *UserSettings) GetAllStrict(alias, key string) ([]string, error) {
	u.doLoadConfigs()
	if u.onceErr != nil && !u.IgnoreErrors {
		return nil, u.onceErr
	}
	for _, cfg := range []*Config{u.userConfig, u.systemConfig} {
		vals, err := findAll(cfg, alias, key)
		if err != nil || vals != nil {
			return vals, err
		}
	}
	// TODO: IdentityFile has multiple default values that we should return.
	def := Default(key)
	if def == "" {
		return []string{}, nil
	}
	return []string{def}, nil
}

// doLoadConfigs reads and parses the user and system config files at most
// once per UserSettings (guarded by u.loadConfigs). A file that does not exist
// is not an error; its *Config is simply left nil.
//
// Both files are always attempted, even if the user file fails to parse. This
// lets callers that set IgnoreErrors still fall back to the system file. The
// first parse error encountered (user file first) is stored in u.onceErr, so
// callers with IgnoreErrors == false see the same error as before.
func (u *UserSettings) doLoadConfigs() {
	u.loadConfigs.Do(func() {
		userFinder := u.userConfigFinder
		if userFinder == nil {
			userFinder = userConfigFinder
		}
		systemFinder := u.systemConfigFinder
		if systemFinder == nil {
			systemFinder = systemConfigFinder
		}

		var err error
		u.userConfig, err = parseFile(userFinder())
		if err != nil && !os.IsNotExist(err) {
			u.onceErr = err
		}

		u.systemConfig, err = parseFile(systemFinder())
		// Keep the user-file error if there was one; it was reported first
		// historically and callers may depend on that.
		if err != nil && !os.IsNotExist(err) && u.onceErr == nil {
			u.onceErr = err
		}
	})
}

// parseFile reads and parses the config file at filename as a top-level file
// (Include depth zero).
func parseFile(filename string) (*Config, error) {
	const topLevel uint8 = 0
	return parseWithDepth(filename, topLevel)
}

// parseWithDepth reads filename and decodes it at the given Include depth.
// Whether the file counts as a system config is decided by isSystem.
// The read error (for example one satisfying os.IsNotExist) is returned
// unchanged.
func parseWithDepth(filename string, depth uint8) (*Config, error) {
	data, err := ioutil.ReadFile(filename)
	if err == nil {
		return decodeBytes(data, isSystem(filename), depth)
	}
	return nil, err
}

// isSystem reports whether filename lives in the system SSH configuration
// directory /etc/ssh. The result becomes decodeBytes' system flag (presumably
// it reaches NewInclude, which resolves relative Include paths against
// /etc/ssh when it is set).
//
// The path is cleaned first. Only /etc/ssh itself or paths inside it count;
// look-alikes such as "/etc/sshd/foo" or "/etc/ssh_config" do not.
func isSystem(filename string) bool {
	// TODO: not sure this is the best way to detect a system repo
	const systemDir = "/etc/ssh"
	clean := filepath.Clean(filename)
	return clean == systemDir || strings.HasPrefix(clean, systemDir+"/")
}

// Decode reads all of r and parses it as an SSH config file. It returns an
// error if reading fails or the contents cannot be parsed. Input from an
// arbitrary reader is never treated as a system file, and parsing starts at
// Include depth zero.
func Decode(r io.Reader) (*Config, error) {
	contents, readErr := ioutil.ReadAll(r)
	if readErr != nil {
		return nil, readErr
	}
	return decodeBytes(contents, false, 0)
}

// decodeBytes lexes and parses b into a Config. system and depth are passed
// through to parseSSH (system is true for files under /etc/ssh; depth is the
// current Include nesting level).
//
// lexSSH/parseSSH signal bad input by panicking, so decodeBytes recovers and
// converts those panics into a returned error:
//   - runtime errors are re-panicked, since they indicate a bug, not bad input;
//   - an error value (for example ErrDepthExceeded) is returned unchanged, so
//     callers can still compare it with ==;
//   - a string is returned as errors.New(msg);
//   - any other panic value is re-panicked unchanged.
//
// On error the returned *Config is nil.
func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// runtime.Error must be checked before error: it also implements error.
		switch v := r.(type) {
		case runtime.Error:
			panic(v)
		case error:
			c, err = nil, v
		case string:
			c, err = nil, errors.New(v)
		default:
			panic(v)
		}
	}()

	c = parseSSH(lexSSH(b), system, depth)
	return c, err
}

// Config represents an SSH config file.
type Config struct {
	// A list of hosts to match against. The file begins with an implicit
	// "Host *" declaration matching all hosts.
	Hosts    []*Host
	// depth is the Include nesting level of this file (0 from newConfig);
	// presumably set by the parser for included files — TODO confirm.
	depth    uint8
	// position is a source Position for this Config; not read in this
	// file — NOTE(review): confirm what the parser stores here.
	position Position
}

// Get finds the first value in the configuration that matches the alias and
// contains key. Host blocks are searched in order, and Include directives are
// searched where they appear. Get returns the empty string and a nil error if
// no value was found.
//
// Get returns an error if it reaches a Match directive (unsupported by this
// package) or a Node of unknown type in a matching Host block before finding
// key. Errors inside files pulled in by an Include directive are not reported.
//
// The match for key is case insensitive.
func (c *Config) Get(alias, key string) (string, error) {
	lowerKey := strings.ToLower(key)
	for _, host := range c.Hosts {
		if !host.Matches(alias) {
			continue
		}
		for _, node := range host.Nodes {
			switch t := node.(type) {
			case *Empty:
				continue
			case *KV:
				// "keys are case insensitive" per the spec
				lkey := strings.ToLower(t.Key)
				if lkey == "match" {
					// Report instead of panicking: a library must not crash
					// its caller over unsupported input.
					return "", errors.New("ssh_config: can't handle Match directives")
				}
				if lkey == lowerKey {
					return t.Value, nil
				}
			case *Include:
				if val := t.Get(alias, key); val != "" {
					return val, nil
				}
			default:
				return "", fmt.Errorf("unknown Node type %v", t)
			}
		}
	}
	return "", nil
}

// GetAll returns all values in the configuration that match the alias and
// contain key, in file order. Values found through Include directives appear
// where the Include appears. It returns nil if none are present.
//
// GetAll returns an error if it reaches a Match directive (unsupported by this
// package) or a Node of unknown type in a matching Host block. It also returns
// any error reported by an Include.
//
// The match for key is case insensitive.
func (c *Config) GetAll(alias, key string) ([]string, error) {
	lowerKey := strings.ToLower(key)
	var all []string
	for _, host := range c.Hosts {
		if !host.Matches(alias) {
			continue
		}
		for _, node := range host.Nodes {
			switch t := node.(type) {
			case *Empty:
				continue
			case *KV:
				// "keys are case insensitive" per the spec
				lkey := strings.ToLower(t.Key)
				if lkey == "match" {
					return nil, errors.New("ssh_config: can't handle Match directives")
				}
				if lkey == lowerKey {
					all = append(all, t.Value)
				}
			case *Include:
				vals, err := t.GetAll(alias, key)
				if err != nil {
					return nil, err
				}
				// Appending an empty slice to nil keeps it nil, so "nothing
				// found" is still reported as nil to callers.
				all = append(all, vals...)
			default:
				return nil, fmt.Errorf("unknown Node type %v", t)
			}
		}
	}

	return all, nil
}

// String returns the Config rendered as SSH config text: the String output of
// each Host, concatenated in order.
func (c Config) String() string {
	return string(marshal(c).Bytes())
}

// MarshalText implements encoding.TextMarshaler. It returns the same text as
// String and never fails.
func (c Config) MarshalText() ([]byte, error) {
	buf := marshal(c)
	return buf.Bytes(), nil
}

// marshal renders every Host of c, in order, into a fresh buffer.
func marshal(c Config) *bytes.Buffer {
	buf := new(bytes.Buffer)
	for _, h := range c.Hosts {
		buf.WriteString(h.String())
	}
	return buf
}

// Pattern is a pattern in a Host declaration. Patterns are read-only values;
// create a new one with NewPattern().
type Pattern struct {
	str   string // Its appearance in the file (including any leading '!'), not the value that gets compiled.
	regex *regexp.Regexp
	not   bool // True if this is a negated match
}

// String returns the pattern exactly as it was passed to NewPattern,
// including the leading '!' of a negated pattern.
func (p Pattern) String() string {
	return p.str
}

// Copied from regexp.go with * and ? removed.
// These are the regexp metacharacters NewPattern escapes; '*' and '?' are the
// ssh_config wildcards and are translated separately.
var specialBytes = []byte(`\.+()|[]{}^$`)

// special reports whether b is a regexp metacharacter that must be escaped.
func special(b byte) bool {
	return bytes.IndexByte(specialBytes, b) >= 0
}

// NewPattern creates a new Pattern for matching hosts. NewPattern("*") creates
// a Pattern that matches all hosts.
//
// From the manpage, a pattern consists of zero or more non-whitespace
// characters, `*' (a wildcard that matches zero or more characters), or `?' (a
// wildcard that matches exactly one character). For example, to specify a set
// of declarations for any host in the ".co.uk" set of domains, the following
// pattern could be used:
//
//	Host *.co.uk
//
// The following pattern would match any host in the 192.168.0.[0-9] network range:
//
//	Host 192.168.0.?
//
// A leading '!' negates the pattern. The '!' is not part of the compiled
// match, but it is kept in String() so the pattern round-trips.
// NewPattern returns an error for the empty string.
func NewPattern(s string) (*Pattern, error) {
	if s == "" {
		return nil, errors.New("ssh_config: empty pattern")
	}
	// Remember the text as written so String() reproduces negated patterns.
	orig := s
	negated := false
	if s[0] == '!' {
		negated = true
		s = s[1:]
	}
	var buf bytes.Buffer
	buf.WriteByte('^')
	for i := 0; i < len(s); i++ {
		// A byte loop is correct because all metacharacters are ASCII.
		switch b := s[i]; b {
		case '*':
			buf.WriteString(".*")
		case '?':
			// Exactly one character, per ssh_config(5); ".?" would also
			// match zero characters.
			buf.WriteByte('.')
		default:
			// borrowing from QuoteMeta here.
			if special(b) {
				buf.WriteByte('\\')
			}
			buf.WriteByte(b)
		}
	}
	buf.WriteByte('$')
	r, err := regexp.Compile(buf.String())
	if err != nil {
		return nil, err
	}
	return &Pattern{str: orig, regex: r, not: negated}, nil
}

// Host describes a Host directive and the keywords that follow it.
type Host struct {
	// A list of host patterns that should match this host.
	Patterns []*Pattern
	// Nodes are the lines belonging to this Host block. Config.Get handles
	// *KV (key/value), *Empty (blank/comment) and *Include nodes.
	Nodes []Node
	// EOLComment is the comment (if any) terminating the Host line.
	EOLComment   string
	// hasEquals records whether the line was written "Host = ..."; String
	// reproduces it.
	hasEquals    bool
	leadingSpace int // Spaces printed before "Host" by String. TODO: handle spaces vs tabs here.
	// The file starts with an implicit "Host *" declaration. String omits the
	// header line for that implicit block.
	implicit bool
}

// Matches reports whether alias selects this Host. At least one pattern must
// match alias, and no negated pattern may match it. A negated match wins
// regardless of the other patterns, as the ssh_config manpage describes: "If
// a negated entry is matched, then the Host entry is ignored, regardless of
// whether any other patterns on the line match."
func (h *Host) Matches(alias string) bool {
	matched := false
	for _, pat := range h.Patterns {
		if !pat.regex.MatchString(alias) {
			continue
		}
		if pat.not {
			return false
		}
		matched = true
	}
	return matched
}

// String renders h as it would appear in a config file: the "Host" header
// line (skipped for the implicit leading block), then one line per Node.
// Whitespace may differ slightly from the original file.
func (h *Host) String() string {
	var out bytes.Buffer
	if !h.implicit {
		out.WriteString(strings.Repeat(" ", h.leadingSpace))
		out.WriteString("Host")
		if h.hasEquals {
			out.WriteString(" = ")
		} else {
			out.WriteByte(' ')
		}
		names := make([]string, len(h.Patterns))
		for i, pat := range h.Patterns {
			names[i] = pat.String()
		}
		out.WriteString(strings.Join(names, " "))
		if h.EOLComment != "" {
			out.WriteString(" #" + h.EOLComment)
		}
		out.WriteByte('\n')
	}
	for _, node := range h.Nodes {
		out.WriteString(node.String())
		out.WriteByte('\n')
	}
	return out.String()
}

// Node represents a line in a Config: a key/value pair, a blank or comment
// line, or an Include directive.
type Node interface {
	// String renders the node as a config-file line (without a newline).
	String() string
	// Pos returns the node's recorded Position.
	Pos() Position
}

// KV is a line in the config file that contains a key, a value, and possibly
// a comment.
type KV struct {
	// Key is the keyword as written; lookups compare it case-insensitively.
	Key          string
	// Value is the text following the key (and optional '=').
	Value        string
	// Comment is the end-of-line comment text, printed after " #".
	Comment      string
	// hasEquals records whether the line used "key = value" form.
	hasEquals    bool
	leadingSpace int // Space before the key. TODO handle spaces vs tabs.
	// position is returned by Pos.
	position     Position
}

// Pos reports the Position recorded for this key/value line.
func (k *KV) Pos() Position {
	pos := k.position
	return pos
}

// String renders k as a config line: leading spaces, the key, a separator
// (" = " if the original used '=', otherwise " "), the value, and " #comment"
// when a comment is present. A nil *KV renders as "". Whitespace between
// tokens may differ from the original file.
func (k *KV) String() string {
	if k == nil {
		return ""
	}
	sep := " "
	if k.hasEquals {
		sep = " = "
	}
	out := strings.Repeat(" ", k.leadingSpace) + k.Key + sep + k.Value
	if k.Comment != "" {
		out += " #" + k.Comment
	}
	return out
}

// Empty is a line in the config file that contains only whitespace or comments.
type Empty struct {
	// Comment is the comment text after '#'; "" for a blank line.
	Comment      string
	leadingSpace int // Spaces printed before '#'. TODO handle spaces vs tabs.
	// position is returned by Pos.
	position     Position
}

// Pos reports the Position recorded for this blank or comment line.
func (e *Empty) Pos() Position {
	pos := e.position
	return pos
}

// String renders e as it was parsed: leading spaces followed by "#comment".
// Lines without a comment, and a nil *Empty, render as "".
func (e *Empty) String() string {
	if e == nil || e.Comment == "" {
		return ""
	}
	return strings.Repeat(" ", e.leadingSpace) + "#" + e.Comment
}

// Include holds the result of an Include directive, including the config files
// that have been parsed as part of that directive. NewInclude refuses to build
// an Include whose depth exceeds maxRecurseDepth (5) and returns
// ErrDepthExceeded instead. This caps how deeply Include statements nest.
type Include struct {
	// Comment is the contents of any comment at the end of the Include
	// statement.
	Comment string
	// directives are the raw arguments of the Include line (file paths or
	// globs), as printed back by String.
	directives []string

	// mu guards matches and files during Get and GetAll.
	mu sync.Mutex
	// matches lists the files the directives expanded to, deduplicated and in
	// expansion order. Every entry is a key in files.
	matches []string
	// files maps each matched filename to its parsed Config.
	files        map[string]*Config
	// leadingSpace is the number of spaces String prints before "Include".
	leadingSpace int
	// position is returned by Pos.
	position     Position
	// depth is the Include nesting depth passed to NewInclude.
	depth        uint8
	// hasEquals records whether the directive was written "Include = ...".
	hasEquals    bool
}

// maxRecurseDepth is the largest depth NewInclude accepts. A deeper request
// fails with ErrDepthExceeded.
const maxRecurseDepth = 5

// ErrDepthExceeded is returned if too many Include directives are parsed.
// Usually this indicates a recursive loop (an Include directive pointing to the
// file it contains). decodeBytes returns this value unwrapped, so callers may
// compare against it directly.
var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded")

// removeDups returns the distinct strings of arr in first-occurrence order.
// The result is always a non-nil slice, even for empty input.
func removeDups(arr []string) []string {
	seen := make(map[string]struct{}, len(arr))
	out := make([]string, 0, len(arr))
	for _, s := range arr {
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}
		out = append(out, s)
	}
	return out
}

// NewInclude creates a new Include with a list of file globs to include.
// Configuration files are parsed greedily (e.g. as soon as this function runs).
// Any error encountered while parsing nested configuration files will be
// returned.
//
// Each directive is resolved following ssh_config(5):
//   - in a user configuration (system == false), a leading "~" or "~/" is
//     replaced by the user's home directory ("~user" forms are not supported);
//   - absolute paths are used as-is;
//   - other relative paths are taken relative to /etc/ssh when system is true,
//     and to ~/.ssh otherwise.
//
// The resulting globs are expanded with filepath.Glob. Duplicates are dropped
// (first occurrence wins), and each match is parsed at the given depth.
// NewInclude returns ErrDepthExceeded if depth > maxRecurseDepth, and the Glob
// error for a malformed pattern.
func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system bool, depth uint8) (*Include, error) {
	if depth > maxRecurseDepth {
		return nil, ErrDepthExceeded
	}
	inc := &Include{
		Comment:      comment,
		directives:   directives,
		files:        make(map[string]*Config),
		position:     pos,
		leadingSpace: pos.Col - 1,
		depth:        depth,
		hasEquals:    hasEquals,
	}
	// no need for inc.mu.Lock() since nothing else can access this inc
	matches := make([]string, 0)
	for _, directive := range directives {
		path := directive
		// Without this, "Include ~/.ssh/config.d/*" would be joined under
		// ~/.ssh/ and silently match nothing.
		if !system && (path == "~" || strings.HasPrefix(path, "~/")) {
			path = filepath.Join(homedir(), path[1:])
		}
		switch {
		case filepath.IsAbs(path):
			// Used as-is.
		case system:
			path = filepath.Join("/etc/ssh", path)
		default:
			path = filepath.Join(homedir(), ".ssh", path)
		}
		theseMatches, err := filepath.Glob(path)
		if err != nil {
			return nil, err
		}
		matches = append(matches, theseMatches...)
	}
	matches = removeDups(matches)
	inc.matches = matches
	for _, match := range matches {
		config, err := parseWithDepth(match, depth)
		if err != nil {
			return nil, err
		}
		inc.files[match] = config
	}
	return inc, nil
}

// Pos reports where the Include directive appeared in the enclosing file.
func (inc *Include) Pos() Position {
	pos := inc.position
	return pos
}

// Get returns the first non-empty value for key that applies to alias,
// searching the included files in match order. Lookup errors inside an
// included file are skipped, and "" is returned when nothing matches.
// Get panics if a matched file has no parsed Config, which would be an
// internal invariant violation.
func (inc *Include) Get(alias, key string) string {
	inc.mu.Lock()
	defer inc.mu.Unlock()
	for _, name := range inc.matches {
		cfg := inc.files[name]
		if cfg == nil {
			panic("nil cfg")
		}
		if val, err := cfg.Get(alias, key); err == nil && val != "" {
			return val
		}
	}
	return ""
}

// GetAll collects every value for key that applies to alias across all
// included files, in match order. Lookup errors inside an included file are
// skipped, so the returned error is always nil. The slice is nil when nothing
// matches. GetAll panics if a matched file has no parsed Config.
func (inc *Include) GetAll(alias, key string) ([]string, error) {
	inc.mu.Lock()
	defer inc.mu.Unlock()

	var collected []string
	for _, name := range inc.matches {
		cfg := inc.files[name]
		if cfg == nil {
			panic("nil cfg")
		}
		found, err := cfg.GetAll(alias, key)
		if err != nil || len(found) == 0 {
			continue
		}
		// In theory if SupportsMultiple was false for this key we could stop
		// looking here. But the caller has asked us to find all instances of
		// the keyword (and could use Get() if they wanted) so keep looking.
		collected = append(collected, found...)
	}
	return collected, nil
}

// String renders the Include directive line as written: leading spaces,
// "Include", a separator (" = " or " "), the space-joined directives, and
// " #comment" if present. The included Config files themselves are not
// printed.
func (inc *Include) String() string {
	sep := " "
	if inc.hasEquals {
		sep = " = "
	}
	var b strings.Builder
	b.WriteString(strings.Repeat(" ", inc.leadingSpace))
	b.WriteString("Include")
	b.WriteString(sep)
	b.WriteString(strings.Join(inc.directives, " "))
	if inc.Comment != "" {
		b.WriteString(" #")
		b.WriteString(inc.Comment)
	}
	return b.String()
}

// matchAll is the "*" pattern used by the implicit Host block that newConfig
// places at the start of every Config. Compiling "*" cannot fail, so a
// failure here is a programming error and panics at package initialization.
var matchAll = func() *Pattern {
	p, err := NewPattern("*")
	if err != nil {
		panic(err)
	}
	return p
}()

// newConfig returns an empty Config at depth zero. It contains only the
// implicit "Host *" block that every SSH config file begins with.
func newConfig() *Config {
	implicitHost := &Host{
		implicit: true,
		Patterns: []*Pattern{matchAll},
		Nodes:    make([]Node, 0),
	}
	return &Config{Hosts: []*Host{implicitHost}}
}