467 lines
9 KiB
Go
467 lines
9 KiB
Go
package main
|
||
|
||
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
	"time"

	"github.com/pkg/sftp"
	"github.com/scaleway/scaleway-sdk-go/api/instance/v1"
	"github.com/scaleway/scaleway-sdk-go/scw"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
)
|
||
|
||
const PARALLELIZE = 16
|
||
const instanceNotFound = Error("instance not found")
|
||
|
||
var msgOut chan string
|
||
var msgErr chan string
|
||
|
||
func main() {
|
||
if len(os.Args) < 2 {
|
||
usage()
|
||
}
|
||
|
||
msgOut = make(chan string, PARALLELIZE)
|
||
msgErr = make(chan string, PARALLELIZE)
|
||
go logger(os.Stdout, msgOut)
|
||
go logger(os.Stderr, msgErr)
|
||
|
||
var err error
|
||
switch os.Args[1] {
|
||
case "spawn":
|
||
err = spawn()
|
||
case "run":
|
||
err = run()
|
||
case "destroy":
|
||
err = destroy()
|
||
}
|
||
|
||
if err != nil {
|
||
msgErr <- fmt.Sprintf("cmd failed: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
}
|
||
|
||
// logger writes every message received on c to o, verbatim and in arrival
// order, and returns once c has been closed and fully drained.
func logger(o *os.File, c chan string) {
	for {
		msg, open := <-c
		if !open {
			return
		}
		fmt.Fprint(o, msg)
	}
}
|
||
|
||
// usage prints the command-line synopsis to stderr and exits with status 1.
//
// It writes to os.Stderr directly instead of through msgErr. The message is
// then printed even if the logger goroutines are not running yet (msgErr may
// still be nil, and a send on a nil channel blocks forever). It is also not
// lost when os.Exit ends the process before a logger has drained the queue.
func usage() {
	programName := "swtool"
	if len(os.Args) > 0 {
		programName = os.Args[0]
	}
	fmt.Fprintf(os.Stderr, "Usage: %v (spawn|run|destroy)\n", programName)
	os.Exit(1)
}
|
||
|
||
/**
|
||
* Errors
|
||
*/
|
||
// Error is a string-backed error type. Because it is a plain string type,
// errors of this kind can be declared as constants (see instanceNotFound).
type Error string

// Error implements the error interface by returning the message unchanged.
func (e Error) Error() string {
	msg := string(e)
	return msg
}
|
||
|
||
/**
|
||
* Parser logic
|
||
*/
|
||
|
||
// instanceReceiver is implemented by each subcommand (spawner, runner,
// destroyer). passInstanceTo calls onInstance once per input record, possibly
// from several goroutines at once. A non-nil error marks that record's
// operation as failed.
type instanceReceiver interface {
	onInstance(zone string, machine string, image string, name string) error
}
|
||
|
||
func passInstanceTo(r io.Reader, d instanceReceiver) error {
|
||
com := make(chan error, PARALLELIZE)
|
||
|
||
count := 0
|
||
failed := 0
|
||
|
||
for {
|
||
var zone, machine, image, name string
|
||
n, err := fmt.Fscanf(r, "%s %s %s %s", &zone, &machine, &image, &name)
|
||
if err == io.EOF || n == 0 {
|
||
break
|
||
} else if err != nil {
|
||
return err
|
||
} else if n != 4 {
|
||
return errors.New(fmt.Sprintf("Wrong number of values (got %d, expected 4)\n", n))
|
||
}
|
||
|
||
go func(zone, machine, image, name string) {
|
||
err := d.onInstance(zone, machine, image, name)
|
||
if err != nil {
|
||
msgErr <- fmt.Sprintf("❌ Operation failed for %v (%v, %v, %v): %v\n", name, zone, machine, image, err)
|
||
}
|
||
com <- err
|
||
}(zone, machine, image, name)
|
||
|
||
count += 1
|
||
}
|
||
|
||
msgOut <- fmt.Sprintf("ℹ️ Waiting for %v servers\n", count)
|
||
for count > 0 {
|
||
err := <- com
|
||
count -= 1
|
||
if err != nil {
|
||
failed += 1
|
||
}
|
||
}
|
||
|
||
if failed > 0 {
|
||
return errors.New(fmt.Sprintf("%d operations failed", failed))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
/**
|
||
* instance wrapper
|
||
*/
|
||
|
||
type action struct {
|
||
api *instance.API
|
||
}
|
||
func (i *action) init() error {
|
||
client, err := getClient()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
i.api = instance.NewAPI(client)
|
||
return nil
|
||
}
|
||
func (i *action) getInstanceByName(zone scw.Zone, name string) (*instance.Server, error) {
|
||
lr, err := i.api.ListServers(&instance.ListServersRequest{
|
||
Zone: zone,
|
||
Name: &name,
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, s := range lr.Servers {
|
||
if s.Name == name {
|
||
return s, nil
|
||
}
|
||
}
|
||
return nil, instanceNotFound
|
||
}
|
||
|
||
func parseIP(s *instance.Server) string {
|
||
ip := "(no address)"
|
||
if s.PublicIP != nil {
|
||
ip = s.PublicIP.Address.String()
|
||
if s.PublicIP.Dynamic {
|
||
ip += " (dynamic)"
|
||
}
|
||
} else if s.PrivateIP != nil {
|
||
ip = fmt.Sprintf("%s %s", *s.PrivateIP, "(private)")
|
||
}
|
||
return ip
|
||
}
|
||
|
||
|
||
/**
|
||
* Spawner
|
||
*/
|
||
type spawner struct { action }
|
||
func (sp *spawner) onInstance(zone, machine, image, name string) error {
|
||
z, err := scw.ParseZone(zone)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
targetServer, err := sp.getInstanceByName(z, name)
|
||
|
||
if err == nil {
|
||
ip := parseIP(targetServer)
|
||
if targetServer.State == instance.ServerStateRunning {
|
||
msgOut <- fmt.Sprintf("🟣 Found %v on zone %v with ip %v\n", targetServer.Name, targetServer.Zone, ip)
|
||
return nil
|
||
}
|
||
} else if err == instanceNotFound {
|
||
// Create a new server
|
||
createRes, err := sp.api.CreateServer(&instance.CreateServerRequest{
|
||
Zone: z,
|
||
Name: name,
|
||
CommercialType: machine,
|
||
Image: image,
|
||
DynamicIPRequired: scw.BoolPtr(true),
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
targetServer = createRes.Server
|
||
} else {
|
||
return err
|
||
}
|
||
|
||
timeout := 5 * time.Minute
|
||
err = sp.api.ServerActionAndWait(&instance.ServerActionAndWaitRequest{
|
||
ServerID: targetServer.ID,
|
||
Action: instance.ServerActionPoweron,
|
||
Zone: z,
|
||
Timeout: &timeout,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
targetServer, err = sp.getInstanceByName(z, name)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
ip := parseIP(targetServer)
|
||
msgOut <- fmt.Sprintf("✅ Started %v on zone %v with ip %v\n", targetServer.Name, targetServer.Zone, ip)
|
||
return nil
|
||
}
|
||
|
||
/**
|
||
* Destroyer
|
||
*/
|
||
type destroyer struct { action }
|
||
|
||
func (dt *destroyer) onInstance(zone, machine, image, name string) error {
|
||
z, err := scw.ParseZone(zone)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
targetServer, err := dt.getInstanceByName(z, name)
|
||
if err == instanceNotFound {
|
||
msgOut <- fmt.Sprintf("🟣 %v is already destroyed\n", name)
|
||
return nil
|
||
} else if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = dt.api.ServerActionAndWait(&instance.ServerActionAndWaitRequest{
|
||
Zone: z,
|
||
ServerID: targetServer.ID,
|
||
Action: instance.ServerActionPoweroff,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = dt.api.DeleteServer(&instance.DeleteServerRequest{
|
||
Zone: z,
|
||
ServerID: targetServer.ID,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
ip := parseIP(targetServer)
|
||
msgOut <- fmt.Sprintf("✅ Destroyed %v on zone %v with ip %v\n", targetServer.Name, targetServer.Zone, ip)
|
||
return nil
|
||
}
|
||
|
||
/**
|
||
* Runner
|
||
*/
|
||
|
||
type runner struct { action }
|
||
|
||
func (r *runner) connect(zone, name string) (*ssh.Client, error) {
|
||
// Connect to the remote
|
||
z, err := scw.ParseZone(zone)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
targetServer, err := r.getInstanceByName(z, name)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if targetServer.PublicIP == nil {
|
||
return nil, errors.New("run failed: this instance has no public ip.")
|
||
}
|
||
ip := targetServer.PublicIP.Address
|
||
|
||
socket := os.Getenv("SSH_AUTH_SOCK")
|
||
conn, err := net.Dial("unix", socket)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
agentClient := agent.NewClient(conn)
|
||
config := &ssh.ClientConfig{
|
||
User: "root",
|
||
Auth: []ssh.AuthMethod{
|
||
ssh.PublicKeysCallback(agentClient.Signers),
|
||
},
|
||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||
}
|
||
|
||
return ssh.Dial("tcp", ip.String()+":22", config)
|
||
}
|
||
|
||
func (r *runner) send(sshClient *ssh.Client) error {
|
||
// Source script
|
||
if len(os.Args) < 3 {
|
||
return errors.New("missing script to run on the command line")
|
||
}
|
||
|
||
src, err := os.Open(os.Args[2])
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer src.Close()
|
||
|
||
// Send our script to the destination
|
||
sftpClient, err := sftp.NewClient(sshClient)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer sftpClient.Close()
|
||
|
||
dst, err := sftpClient.Create("/tmp/nuage")
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer dst.Close()
|
||
|
||
err = dst.Chmod(0755)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
_, err = io.Copy(dst, src)
|
||
return err
|
||
}
|
||
|
||
// readerToChan forwards r to c one line at a time, until the first read error
// (including io.EOF). Each message keeps its trailing '\n'. A final line
// without one is still forwarded as-is.
func readerToChan(r io.Reader, c chan string) {
	br := bufio.NewReader(r)
	for {
		line, err := br.ReadString('\n')
		// ReadString returns whatever it read before failing, so forward that
		// first; otherwise an unterminated last line would be lost.
		if line != "" {
			c <- line
		}
		if err != nil {
			return
		}
	}
}
|
||
func (r *runner) exec(sshClient *ssh.Client) error {
|
||
// Run the script
|
||
session, err := sshClient.NewSession()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer session.Close()
|
||
|
||
if os.Getenv("VERBOSE") != "" {
|
||
readErr, err := session.StderrPipe()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
readerToChan(readErr, msgErr)
|
||
|
||
readOut, err := session.StdoutPipe()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
readerToChan(readOut, msgOut)
|
||
}
|
||
|
||
err = session.Run("/tmp/nuage")
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *runner) onInstance(zone, machine, image, name string) error {
|
||
sshClient, err := r.connect(zone, name)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer sshClient.Close()
|
||
|
||
err = r.send(sshClient)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = r.exec(sshClient)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
msgOut <- fmt.Sprintf("✅ Successfully ran the script on %v (zone %v)\n", name, zone)
|
||
return nil
|
||
}
|
||
|
||
|
||
/**
|
||
* Commands
|
||
*/
|
||
|
||
func spawn() error {
|
||
sp := spawner{}
|
||
err := sp.init()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = passInstanceTo(os.Stdin, &sp)
|
||
return err
|
||
}
|
||
|
||
func run() error {
|
||
r := runner{}
|
||
err := r.init()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = passInstanceTo(os.Stdin, &r)
|
||
return err
|
||
}
|
||
|
||
func destroy() error {
|
||
dt := destroyer{}
|
||
err := dt.init()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
err = passInstanceTo(os.Stdin, &dt)
|
||
return err
|
||
}
|
||
|
||
|
||
func getClient() (*scw.Client, error) {
|
||
// Get config
|
||
// Install scw: https://github.com/scaleway/scaleway-cli
|
||
// And run `scw init` to create the config file
|
||
config, err := scw.LoadConfig()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// Use active profile
|
||
profile, err := config.GetActiveProfile()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// We use the default profile
|
||
client, err := scw.NewClient(
|
||
scw.WithProfile(profile),
|
||
scw.WithEnv(), // env variable may overwrite profile values
|
||
)
|
||
return client, err
|
||
}
|