0
0
mirror of https://github.com/bpg/terraform-provider-proxmox.git synced 2025-06-30 10:33:46 +00:00
terraform-provider-proxmox/proxmox/virtual_environment_nodes.go
zoop 9fa92423b5
feat: SSH-Agent Support (#306)
* chore: add agent configuration bool

* feat: add ssh-agent authentication mechanism for linux

* chore: make sure ssh-agent auth is only executed on linux

* chore: add ssh user override

* chore: add ssh configuration block, check ssh config during VirtualEnvironmentClient creation

* fix: handle case of empty ssh config block

* chore: add ssh password auth fallback logic

* fix: remove not needed runtime

* fix linter errors & re-format

* allow ssh agent on all POSIX systems

* add `agent_socket` parameter

* update docs and examples

---------

Co-authored-by: zoop <zoop@zoop.li>
Co-authored-by: Pavel Boldyrev <627562+bpg@users.noreply.github.com>
2023-05-22 13:34:24 -04:00

380 lines
9.4 KiB
Go

/*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at https://mozilla.org/MPL/2.0/.
*/
package proxmox
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path"
"sort"
"strings"
"time"
"github.com/hashicorp/terraform-plugin-log/tflog"
"github.com/skeema/knownhosts"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// ExecuteNodeCommands executes commands on a given node.
func (c *VirtualEnvironmentClient) ExecuteNodeCommands(
ctx context.Context,
nodeName string,
commands []string,
) error {
closeOrLogError := CloseOrLogError(ctx)
sshClient, err := c.OpenNodeShell(ctx, nodeName)
if err != nil {
return err
}
defer closeOrLogError(sshClient)
sshSession, err := sshClient.NewSession()
if err != nil {
return err
}
defer closeOrLogError(sshSession)
script := strings.Join(commands, " && \\\n")
output, err := sshSession.CombinedOutput(
fmt.Sprintf(
"/bin/bash -c '%s'",
strings.ReplaceAll(script, "'", "'\"'\"'"),
),
)
if err != nil {
return errors.New(string(output))
}
return nil
}
// GetNodeIP retrieves the IP address of a node.
func (c *VirtualEnvironmentClient) GetNodeIP(
ctx context.Context,
nodeName string,
) (*string, error) {
networkDevices, err := c.ListNodeNetworkDevices(ctx, nodeName)
if err != nil {
return nil, err
}
nodeAddress := ""
for _, d := range networkDevices {
if d.Address != nil {
nodeAddress = *d.Address
break
}
}
if nodeAddress == "" {
return nil, fmt.Errorf("failed to determine the IP address of node \"%s\"", nodeName)
}
nodeAddressParts := strings.Split(nodeAddress, "/")
return &nodeAddressParts[0], nil
}
// GetNodeTime retrieves the time information for a node.
func (c *VirtualEnvironmentClient) GetNodeTime(
ctx context.Context,
nodeName string,
) (*VirtualEnvironmentNodeGetTimeResponseData, error) {
resBody := &VirtualEnvironmentNodeGetTimeResponseBody{}
err := c.DoRequest(
ctx,
http.MethodGet,
fmt.Sprintf("nodes/%s/time", url.PathEscape(nodeName)),
nil,
resBody,
)
if err != nil {
return nil, err
}
if resBody.Data == nil {
return nil, errors.New("the server did not include a data object in the response")
}
return resBody.Data, nil
}
// GetNodeTaskStatus retrieves the status of a node task.
func (c *VirtualEnvironmentClient) GetNodeTaskStatus(
ctx context.Context,
nodeName string,
upid string,
) (*VirtualEnvironmentNodeGetTaskStatusResponseData, error) {
resBody := &VirtualEnvironmentNodeGetTaskStatusResponseBody{}
err := c.DoRequest(
ctx,
http.MethodGet,
fmt.Sprintf("nodes/%s/tasks/%s/status", url.PathEscape(nodeName), url.PathEscape(upid)),
nil,
resBody,
)
if err != nil {
return nil, err
}
if resBody.Data == nil {
return nil, errors.New("the server did not include a data object in the response")
}
return resBody.Data, nil
}
// ListNodeNetworkDevices retrieves a list of network devices for a specific nodes.
func (c *VirtualEnvironmentClient) ListNodeNetworkDevices(
ctx context.Context,
nodeName string,
) ([]*VirtualEnvironmentNodeNetworkDeviceListResponseData, error) {
resBody := &VirtualEnvironmentNodeNetworkDeviceListResponseBody{}
err := c.DoRequest(
ctx,
http.MethodGet,
fmt.Sprintf("nodes/%s/network", url.PathEscape(nodeName)),
nil,
resBody,
)
if err != nil {
return nil, err
}
if resBody.Data == nil {
return nil, errors.New("the server did not include a data object in the response")
}
sort.Slice(resBody.Data, func(i, j int) bool {
return resBody.Data[i].Priority < resBody.Data[j].Priority
})
return resBody.Data, nil
}
// ListNodes retrieves a list of nodes.
func (c *VirtualEnvironmentClient) ListNodes(
ctx context.Context,
) ([]*VirtualEnvironmentNodeListResponseData, error) {
resBody := &VirtualEnvironmentNodeListResponseBody{}
err := c.DoRequest(ctx, http.MethodGet, "nodes", nil, resBody)
if err != nil {
return nil, err
}
if resBody.Data == nil {
return nil, errors.New("the server did not include a data object in the response")
}
sort.Slice(resBody.Data, func(i, j int) bool {
return resBody.Data[i].Name < resBody.Data[j].Name
})
return resBody.Data, nil
}
// OpenNodeShell establishes a new SSH connection to a node.
func (c *VirtualEnvironmentClient) OpenNodeShell(
ctx context.Context,
nodeName string,
) (*ssh.Client, error) {
nodeAddress, err := c.GetNodeIP(ctx, nodeName)
if err != nil {
return nil, err
}
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to determine the home directory: %w", err)
}
sshHost := fmt.Sprintf("%s:22", *nodeAddress)
sshPath := path.Join(homeDir, ".ssh")
if _, err = os.Stat(sshPath); os.IsNotExist(err) {
e := os.Mkdir(sshPath, 0o700)
if e != nil {
return nil, fmt.Errorf("failed to create %s: %w", sshPath, e)
}
}
khPath := path.Join(sshPath, "known_hosts")
if _, err = os.Stat(khPath); os.IsNotExist(err) {
e := os.WriteFile(khPath, []byte{}, 0o600)
if e != nil {
return nil, fmt.Errorf("failed to create %s: %w", khPath, e)
}
}
kh, err := knownhosts.New(khPath)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", khPath, err)
}
// Create a custom permissive hostkey callback which still errors on hosts
// with changed keys, but allows unknown hosts and adds them to known_hosts
cb := ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error {
err := kh(hostname, remote, key)
if knownhosts.IsHostKeyChanged(err) {
return fmt.Errorf("REMOTE HOST IDENTIFICATION HAS CHANGED for host %s! This may indicate a MitM attack", hostname)
}
if knownhosts.IsHostUnknown(err) {
f, ferr := os.OpenFile(khPath, os.O_APPEND|os.O_WRONLY, 0o600)
if ferr == nil {
defer CloseOrLogError(ctx)(f)
ferr = knownhosts.WriteKnownHost(f, hostname, remote, key)
}
if ferr == nil {
tflog.Info(ctx, fmt.Sprintf("Added host %s to known_hosts", hostname))
} else {
tflog.Error(ctx, fmt.Sprintf("Failed to add host %s to known_hosts", hostname), map[string]interface{}{
"error": err,
})
}
return nil
}
return err
})
sshConfig := &ssh.ClientConfig{
User: c.SSHUsername,
Auth: []ssh.AuthMethod{ssh.Password(c.SSHPassword)},
HostKeyCallback: cb,
HostKeyAlgorithms: kh.HostKeyAlgorithms(sshHost),
}
tflog.Info(ctx, fmt.Sprintf("Agent is set to %t", c.SSHAgent))
if c.SSHAgent {
sshClient, err := c.CreateSSHClientAgent(ctx, cb, kh, sshHost)
if err != nil {
tflog.Error(ctx, "Failed ssh connection through agent, "+
"falling back to password authentication",
map[string]interface{}{
"error": err,
})
} else {
return sshClient, nil
}
}
sshClient, err := ssh.Dial("tcp", sshHost, sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial %s: %w", sshHost, err)
}
tflog.Debug(ctx, "SSH connection established", map[string]interface{}{
"host": sshHost,
"user": c.SSHUsername,
})
return sshClient, nil
}
// CreateSSHClientAgent establishes an ssh connection through the agent authentication mechanism
func (c *VirtualEnvironmentClient) CreateSSHClientAgent(
ctx context.Context,
cb ssh.HostKeyCallback,
kh knownhosts.HostKeyCallback,
sshHost string,
) (*ssh.Client, error) {
if c.SSHAgentSocket == "" {
return nil, errors.New("failed connecting to SSH agent socket: the socket file is not defined, " +
"authentication will fall back to password")
}
conn, err := net.Dial("unix", c.SSHAgentSocket)
if err != nil {
return nil, fmt.Errorf("failed connecting to SSH auth socket '%s': %w", c.SSHAgentSocket, err)
}
ag := agent.NewClient(conn)
sshConfig := &ssh.ClientConfig{
User: c.SSHUsername,
Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(ag.Signers), ssh.Password(c.SSHPassword)},
HostKeyCallback: cb,
HostKeyAlgorithms: kh.HostKeyAlgorithms(sshHost),
}
sshClient, err := ssh.Dial("tcp", sshHost, sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial %s: %w", sshHost, err)
}
tflog.Debug(ctx, "SSH connection established", map[string]interface{}{
"host": sshHost,
"user": c.SSHUsername,
})
return sshClient, nil
}
// UpdateNodeTime updates the time on a node.
func (c *VirtualEnvironmentClient) UpdateNodeTime(
ctx context.Context,
nodeName string,
d *VirtualEnvironmentNodeUpdateTimeRequestBody,
) error {
return c.DoRequest(ctx, http.MethodPut, fmt.Sprintf("nodes/%s/time", url.PathEscape(nodeName)), d, nil)
}
// WaitForNodeTask waits for a specific node task to complete.
func (c *VirtualEnvironmentClient) WaitForNodeTask(
ctx context.Context,
nodeName string,
upid string,
timeout int,
delay int,
) error {
timeDelay := int64(delay)
timeMax := float64(timeout)
timeStart := time.Now()
timeElapsed := timeStart.Sub(timeStart)
for timeElapsed.Seconds() < timeMax {
if int64(timeElapsed.Seconds())%timeDelay == 0 {
status, err := c.GetNodeTaskStatus(ctx, nodeName, upid)
if err != nil {
return err
}
if status.Status != "running" {
if status.ExitCode != "OK" {
return fmt.Errorf(
"task \"%s\" on node \"%s\" failed to complete with error: %s",
upid,
nodeName,
status.ExitCode,
)
}
return nil
}
time.Sleep(1 * time.Second)
}
time.Sleep(200 * time.Millisecond)
timeElapsed = time.Since(timeStart)
if ctx.Err() != nil {
return ctx.Err()
}
}
return fmt.Errorf(
"timeout while waiting for task \"%s\" on node \"%s\" to complete",
upid,
nodeName,
)
}