286 lines
6.4 KiB
Go
286 lines
6.4 KiB
Go
|
package executor
|
||
|
|
||
|
import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"syscall"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/nomad/plugins/drivers"
	dproto "github.com/hashicorp/nomad/plugins/drivers/proto"
)
|
||
|
|
||
|
// execHelper is a convenient wrapper for starting and executing commands, and handling their output
|
||
|
type execHelper struct {
|
||
|
logger hclog.Logger
|
||
|
|
||
|
// newTerminal function creates a tty appropriate for the command
|
||
|
// The returned pty end of tty function is to be called after process start.
|
||
|
newTerminal func() (pty func() (*os.File, error), tty *os.File, err error)
|
||
|
|
||
|
// setTTY is a callback to configure the command with slave end of the tty of the terminal, when tty is enabled
|
||
|
setTTY func(tty *os.File) error
|
||
|
|
||
|
// setTTY is a callback to configure the command with std{in|out|err}, when tty is disabled
|
||
|
setIO func(stdin io.Reader, stdout, stderr io.Writer) error
|
||
|
|
||
|
// processStart starts the process, like `exec.Cmd.Start()`
|
||
|
processStart func() error
|
||
|
|
||
|
// processWait blocks until command terminates and returns its final state
|
||
|
processWait func() (*os.ProcessState, error)
|
||
|
}
|
||
|
|
||
|
func (e *execHelper) run(ctx context.Context, tty bool, stream drivers.ExecTaskStream) error {
|
||
|
if tty {
|
||
|
return e.runTTY(ctx, stream)
|
||
|
}
|
||
|
return e.runNoTTY(ctx, stream)
|
||
|
}
|
||
|
|
||
|
func (e *execHelper) runTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
|
||
|
ptyF, tty, err := e.newTerminal()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to open a tty: %v", err)
|
||
|
}
|
||
|
defer tty.Close()
|
||
|
|
||
|
if err := e.setTTY(tty); err != nil {
|
||
|
return fmt.Errorf("failed to set command tty: %v", err)
|
||
|
}
|
||
|
if err := e.processStart(); err != nil {
|
||
|
return fmt.Errorf("failed to start command: %v", err)
|
||
|
}
|
||
|
|
||
|
var wg sync.WaitGroup
|
||
|
errCh := make(chan error, 3)
|
||
|
|
||
|
pty, err := ptyF()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to get pty: %v", err)
|
||
|
}
|
||
|
|
||
|
defer pty.Close()
|
||
|
wg.Add(1)
|
||
|
go handleStdin(e.logger, pty, stream, errCh)
|
||
|
// when tty is on, stdout and stderr point to the same pty so only read once
|
||
|
go handleStdout(e.logger, pty, &wg, stream.Send, errCh)
|
||
|
|
||
|
ps, err := e.processWait()
|
||
|
|
||
|
// force close streams to close out the stream copying goroutines
|
||
|
tty.Close()
|
||
|
|
||
|
// wait until we get all process output
|
||
|
wg.Wait()
|
||
|
|
||
|
// wait to flush out output
|
||
|
stream.Send(cmdExitResult(ps, err))
|
||
|
|
||
|
select {
|
||
|
case cerr := <-errCh:
|
||
|
return cerr
|
||
|
default:
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (e *execHelper) runNoTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
|
||
|
var sendLock sync.Mutex
|
||
|
send := func(v *drivers.ExecTaskStreamingResponseMsg) error {
|
||
|
sendLock.Lock()
|
||
|
defer sendLock.Unlock()
|
||
|
|
||
|
return stream.Send(v)
|
||
|
}
|
||
|
|
||
|
stdinPr, stdinPw := io.Pipe()
|
||
|
stdoutPr, stdoutPw := io.Pipe()
|
||
|
stderrPr, stderrPw := io.Pipe()
|
||
|
|
||
|
defer stdoutPw.Close()
|
||
|
defer stderrPw.Close()
|
||
|
|
||
|
if err := e.setIO(stdinPr, stdoutPw, stderrPw); err != nil {
|
||
|
return fmt.Errorf("failed to set command io: %v", err)
|
||
|
}
|
||
|
|
||
|
if err := e.processStart(); err != nil {
|
||
|
return fmt.Errorf("failed to start command: %v", err)
|
||
|
}
|
||
|
|
||
|
var wg sync.WaitGroup
|
||
|
errCh := make(chan error, 3)
|
||
|
|
||
|
wg.Add(2)
|
||
|
go handleStdin(e.logger, stdinPw, stream, errCh)
|
||
|
go handleStdout(e.logger, stdoutPr, &wg, send, errCh)
|
||
|
go handleStderr(e.logger, stderrPr, &wg, send, errCh)
|
||
|
|
||
|
ps, err := e.processWait()
|
||
|
|
||
|
// force close streams to close out the stream copying goroutines
|
||
|
stdinPr.Close()
|
||
|
stdoutPw.Close()
|
||
|
stderrPw.Close()
|
||
|
|
||
|
// wait until we get all process output
|
||
|
wg.Wait()
|
||
|
|
||
|
// wait to flush out output
|
||
|
stream.Send(cmdExitResult(ps, err))
|
||
|
|
||
|
select {
|
||
|
case cerr := <-errCh:
|
||
|
return cerr
|
||
|
default:
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
func cmdExitResult(ps *os.ProcessState, err error) *drivers.ExecTaskStreamingResponseMsg {
|
||
|
exitCode := -1
|
||
|
|
||
|
if ps == nil {
|
||
|
if ee, ok := err.(*exec.ExitError); ok {
|
||
|
ps = ee.ProcessState
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if ps == nil {
|
||
|
exitCode = -2
|
||
|
} else if status, ok := ps.Sys().(syscall.WaitStatus); ok {
|
||
|
exitCode = status.ExitStatus()
|
||
|
if status.Signaled() {
|
||
|
const exitSignalBase = 128
|
||
|
signal := int(status.Signal())
|
||
|
exitCode = exitSignalBase + signal
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return &drivers.ExecTaskStreamingResponseMsg{
|
||
|
Exited: true,
|
||
|
Result: &dproto.ExitResult{
|
||
|
ExitCode: int32(exitCode),
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleStdin(logger hclog.Logger, stdin io.WriteCloser, stream drivers.ExecTaskStream, errCh chan<- error) {
|
||
|
for {
|
||
|
m, err := stream.Recv()
|
||
|
if isClosedError(err) {
|
||
|
return
|
||
|
} else if err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if m.Stdin != nil {
|
||
|
if len(m.Stdin.Data) != 0 {
|
||
|
_, err := stdin.Write(m.Stdin.Data)
|
||
|
if err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
if m.Stdin.Close {
|
||
|
stdin.Close()
|
||
|
}
|
||
|
} else if m.TtySize != nil {
|
||
|
err := setTTYSize(stdin, m.TtySize.Height, m.TtySize.Width)
|
||
|
if err != nil {
|
||
|
errCh <- fmt.Errorf("failed to resize tty: %v", err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleStdout(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
|
||
|
defer wg.Done()
|
||
|
|
||
|
buf := make([]byte, 4096)
|
||
|
for {
|
||
|
n, err := reader.Read(buf)
|
||
|
// always send output first if we read something
|
||
|
if n > 0 {
|
||
|
if err := send(&drivers.ExecTaskStreamingResponseMsg{
|
||
|
Stdout: &dproto.ExecTaskStreamingIOOperation{
|
||
|
Data: buf[:n],
|
||
|
},
|
||
|
}); err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// then process error
|
||
|
if isClosedError(err) {
|
||
|
if err := send(&drivers.ExecTaskStreamingResponseMsg{
|
||
|
Stdout: &dproto.ExecTaskStreamingIOOperation{
|
||
|
Close: true,
|
||
|
},
|
||
|
}); err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
return
|
||
|
} else if err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func handleStderr(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
|
||
|
defer wg.Done()
|
||
|
|
||
|
buf := make([]byte, 4096)
|
||
|
for {
|
||
|
n, err := reader.Read(buf)
|
||
|
// always send output first if we read something
|
||
|
if n > 0 {
|
||
|
if err := send(&drivers.ExecTaskStreamingResponseMsg{
|
||
|
Stderr: &dproto.ExecTaskStreamingIOOperation{
|
||
|
Data: buf[:n],
|
||
|
},
|
||
|
}); err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// then process error
|
||
|
if isClosedError(err) {
|
||
|
if err := send(&drivers.ExecTaskStreamingResponseMsg{
|
||
|
Stderr: &dproto.ExecTaskStreamingIOOperation{
|
||
|
Close: true,
|
||
|
},
|
||
|
}); err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
return
|
||
|
} else if err != nil {
|
||
|
errCh <- err
|
||
|
return
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func isClosedError(err error) bool {
|
||
|
if err == nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return err == io.EOF ||
|
||
|
err == io.ErrClosedPipe ||
|
||
|
isUnixEIOErr(err)
|
||
|
}
|