cmd/puppeth: accept ssh identity in the server string (#17407)
* cmd/puppeth: Accept identityfile in the server string with fallback to id_rsa * cmd/puppeth: code polishes + fix heath check double ports
This commit is contained in:
		
				
					committed by
					
						 Péter Szilágyi
						Péter Szilágyi
					
				
			
			
				
	
			
			
			
						parent
						
							1de9ada401
						
					
				
				
					commit
					7d38d53ae4
				
			| @@ -45,33 +45,44 @@ type sshClient struct { | |||||||
|  |  | ||||||
| // dial establishes an SSH connection to a remote node using the current user and | // dial establishes an SSH connection to a remote node using the current user and | ||||||
| // the user's configured private RSA key. If that fails, password authentication | // the user's configured private RSA key. If that fails, password authentication | ||||||
| // is fallen back to. The caller may override the login user via user@server:port. | // is fallen back to. server can be a string like user:identity@server:port. | ||||||
| func dial(server string, pubkey []byte) (*sshClient, error) { | func dial(server string, pubkey []byte) (*sshClient, error) { | ||||||
| 	// Figure out a label for the server and a logger | 	// Figure out username, identity, hostname and port | ||||||
| 	label := server | 	hostname := "" | ||||||
| 	if strings.Contains(label, ":") { | 	hostport := server | ||||||
| 		label = label[:strings.Index(label, ":")] | 	username := "" | ||||||
| 	} | 	identity := "id_rsa" // default | ||||||
| 	login := "" |  | ||||||
| 	if strings.Contains(server, "@") { | 	if strings.Contains(server, "@") { | ||||||
| 		login = label[:strings.Index(label, "@")] | 		prefix := server[:strings.Index(server, "@")] | ||||||
| 		label = label[strings.Index(label, "@")+1:] | 		if strings.Contains(prefix, ":") { | ||||||
| 		server = server[strings.Index(server, "@")+1:] | 			username = prefix[:strings.Index(prefix, ":")] | ||||||
|  | 			identity = prefix[strings.Index(prefix, ":")+1:] | ||||||
|  | 		} else { | ||||||
|  | 			username = prefix | ||||||
|  | 		} | ||||||
|  | 		hostport = server[strings.Index(server, "@")+1:] | ||||||
| 	} | 	} | ||||||
| 	logger := log.New("server", label) | 	if strings.Contains(hostport, ":") { | ||||||
|  | 		hostname = hostport[:strings.Index(hostport, ":")] | ||||||
|  | 	} else { | ||||||
|  | 		hostname = hostport | ||||||
|  | 		hostport += ":22" | ||||||
|  | 	} | ||||||
|  | 	logger := log.New("server", server) | ||||||
| 	logger.Debug("Attempting to establish SSH connection") | 	logger.Debug("Attempting to establish SSH connection") | ||||||
|  |  | ||||||
| 	user, err := user.Current() | 	user, err := user.Current() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	if login == "" { | 	if username == "" { | ||||||
| 		login = user.Username | 		username = user.Username | ||||||
| 	} | 	} | ||||||
| 	// Configure the supported authentication methods (private key and password) | 	// Configure the supported authentication methods (private key and password) | ||||||
| 	var auths []ssh.AuthMethod | 	var auths []ssh.AuthMethod | ||||||
|  |  | ||||||
| 	path := filepath.Join(user.HomeDir, ".ssh", "id_rsa") | 	path := filepath.Join(user.HomeDir, ".ssh", identity) | ||||||
| 	if buf, err := ioutil.ReadFile(path); err != nil { | 	if buf, err := ioutil.ReadFile(path); err != nil { | ||||||
| 		log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) | 		log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) | ||||||
| 	} else { | 	} else { | ||||||
| @@ -94,14 +105,14 @@ func dial(server string, pubkey []byte) (*sshClient, error) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	auths = append(auths, ssh.PasswordCallback(func() (string, error) { | 	auths = append(auths, ssh.PasswordCallback(func() (string, error) { | ||||||
| 		fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server) | 		fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server) | ||||||
| 		blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) | 		blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) | ||||||
|  |  | ||||||
| 		fmt.Println() | 		fmt.Println() | ||||||
| 		return string(blob), err | 		return string(blob), err | ||||||
| 	})) | 	})) | ||||||
| 	// Resolve the IP address of the remote server | 	// Resolve the IP address of the remote server | ||||||
| 	addr, err := net.LookupHost(label) | 	addr, err := net.LookupHost(hostname) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -109,10 +120,7 @@ func dial(server string, pubkey []byte) (*sshClient, error) { | |||||||
| 		return nil, errors.New("no IPs associated with domain") | 		return nil, errors.New("no IPs associated with domain") | ||||||
| 	} | 	} | ||||||
| 	// Try to dial in to the remote server | 	// Try to dial in to the remote server | ||||||
| 	logger.Trace("Dialing remote SSH server", "user", login) | 	logger.Trace("Dialing remote SSH server", "user", username) | ||||||
| 	if !strings.Contains(server, ":") { |  | ||||||
| 		server += ":22" |  | ||||||
| 	} |  | ||||||
| 	keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { | 	keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { | ||||||
| 		// If no public key is known for SSH, ask the user to confirm | 		// If no public key is known for SSH, ask the user to confirm | ||||||
| 		if pubkey == nil { | 		if pubkey == nil { | ||||||
| @@ -139,13 +147,13 @@ func dial(server string, pubkey []byte) (*sshClient, error) { | |||||||
| 		// We have a mismatch, forbid connecting | 		// We have a mismatch, forbid connecting | ||||||
| 		return errors.New("ssh key mismatch, readd the machine to update") | 		return errors.New("ssh key mismatch, readd the machine to update") | ||||||
| 	} | 	} | ||||||
| 	client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck}) | 	client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	// Connection established, return our utility wrapper | 	// Connection established, return our utility wrapper | ||||||
| 	c := &sshClient{ | 	c := &sshClient{ | ||||||
| 		server:  label, | 		server:  hostname, | ||||||
| 		address: addr[0], | 		address: addr[0], | ||||||
| 		pubkey:  pubkey, | 		pubkey:  pubkey, | ||||||
| 		client:  client, | 		client:  client, | ||||||
|   | |||||||
| @@ -62,14 +62,14 @@ func (w *wizard) manageServers() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // makeServer reads a single line from stdin and interprets it as a hostname to | // makeServer reads a single line from stdin and interprets it as | ||||||
| // connect to. It tries to establish a new SSH session and also executing some | // username:identity@hostname to connect to. It tries to establish a | ||||||
| // baseline validations. | // new SSH session and also executing some baseline validations. | ||||||
| // | // | ||||||
| // If connection succeeds, the server is added to the wizards configs! | // If connection succeeds, the server is added to the wizards configs! | ||||||
| func (w *wizard) makeServer() string { | func (w *wizard) makeServer() string { | ||||||
| 	fmt.Println() | 	fmt.Println() | ||||||
| 	fmt.Println("Please enter remote server's address:") | 	fmt.Println("What is the remote server's address ([username[:identity]@]hostname[:port])?") | ||||||
|  |  | ||||||
| 	// Read and dial the server to ensure docker is present | 	// Read and dial the server to ensure docker is present | ||||||
| 	input := w.readString() | 	input := w.readString() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user