package main

import (
  "bufio"
  "io"
  "net"
  "os"
  "fmt"
  "errors"
  "time"

  "golang.org/x/crypto/ssh"
  "golang.org/x/crypto/ssh/agent"

  "github.com/pkg/sftp"

  "github.com/scaleway/scaleway-sdk-go/scw"
  "github.com/scaleway/scaleway-sdk-go/api/instance/v1"
)

const (
	// PARALLELIZE is the tool-wide channel buffer / parallelism size.
	PARALLELIZE = 16

	// instanceNotFound is returned by action.getInstanceByName when no server
	// carries the exact requested name; callers compare against it with ==.
	instanceNotFound = Error("instance not found")
)

// msgOut and msgErr carry preformatted messages that logger goroutines print
// to stdout and stderr respectively. They are created in main.
var (
	msgOut chan string
	msgErr chan string
)

func main() {
  if len(os.Args) < 2 {
    usage()
  }

  msgOut = make(chan string, PARALLELIZE)
  msgErr = make(chan string, PARALLELIZE)
  go logger(os.Stdout, msgOut)
  go logger(os.Stderr, msgErr)

  var err error
  switch os.Args[1] {
    case "spawn": 
      err = spawn()
    case "run": 
      err = run()
    case "destroy": 
      err = destroy()
  }

  if err != nil {
    msgErr <- fmt.Sprintf("cmd failed: %v\n", err)
    os.Exit(1)
  }
}

// logger writes every message received on c to o, verbatim and in arrival
// order, and returns once c has been closed and emptied.
func logger(o *os.File, c chan string) {
	for {
		msg, open := <-c
		if !open {
			return
		}
		fmt.Fprint(o, msg)
	}
}

// usage prints the command-line synopsis to stderr and exits with status 1.
//
// It writes to os.Stderr directly instead of sending on msgErr. os.Exit does
// not wait for the logger goroutine, so a queued message could be lost. Also,
// usage may be reached before msgErr has been created, and a send on a nil
// channel blocks forever.
func usage() {
	programName := "swtool"
	if len(os.Args) > 0 {
		programName = os.Args[0]
	}
	fmt.Fprintf(os.Stderr, "Usage: %v (spawn|run|destroy)\n", programName)
	os.Exit(1)
}

/**
 * Errors
 */
// Error is a string-backed error type. Because its underlying type is string,
// sentinel errors such as instanceNotFound can be declared as constants.
type Error string

// Error implements the error interface by returning the underlying text.
func (e Error) Error() string {
	msg := string(e)
	return msg
}

/**
 * Parser logic
 */

// instanceReceiver is implemented by each command (spawner, runner, destroyer).
// passInstanceTo calls onInstance once per instance line, from concurrently
// running goroutines, and counts a non-nil error as a failed operation.
type instanceReceiver interface {
	// onInstance handles one "zone machine image name" entry.
	onInstance(zone, machine, image, name string) error
}

// instanceSpec is one "zone machine image name" line of the instance list.
type instanceSpec struct {
	zone, machine, image, name string
}

// parseInstanceList reads the whole instance list from r.
//
// Each line holds exactly four whitespace-separated values: zone, commercial
// type (machine), image and name. Blank lines, and lines whose first value
// starts with '#', are skipped. Any other line without exactly four values is
// an error that reports its 1-based line number.
func parseInstanceList(r io.Reader) ([]instanceSpec, error) {
	var specs []instanceSpec
	sc := bufio.NewScanner(r)
	for lineNo := 1; sc.Scan(); lineNo++ {
		line := sc.Text()
		// The fifth slot catches surplus values. Sscan stops at the first
		// argument it cannot fill, so n is the value count (capped at 5).
		// Its error only says the line ran out early, which n already tells us.
		var f [5]string
		n, _ := fmt.Sscan(line, &f[0], &f[1], &f[2], &f[3], &f[4])
		if n == 0 || f[0][0] == '#' {
			continue
		}
		if n != 4 {
			return nil, fmt.Errorf("line %d: expected 4 values (zone machine image name), got %q", lineNo, line)
		}
		specs = append(specs, instanceSpec{zone: f[0], machine: f[1], image: f[2], name: f[3]})
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return specs, nil
}

// passInstanceTo parses the instance list from r (see parseInstanceList) and
// calls d.onInstance once per entry. Calls run in parallel, with at most
// PARALLELIZE in flight. Each failure is reported on msgErr as it happens.
//
// The whole list is validated before any operation starts. A malformed line
// therefore cannot leave earlier operations half-done when the returned error
// makes the program exit.
//
// It returns an error giving the number of failed operations, or nil if all
// of them succeeded.
func passInstanceTo(r io.Reader, d instanceReceiver) error {
	specs, err := parseInstanceList(r)
	if err != nil {
		return err
	}

	results := make(chan error, len(specs)) // sized so a worker never blocks on it
	slots := make(chan struct{}, PARALLELIZE)
	for _, s := range specs {
		go func(s instanceSpec) {
			slots <- struct{}{}
			err := d.onInstance(s.zone, s.machine, s.image, s.name)
			<-slots
			if err != nil {
				msgErr <- fmt.Sprintf("❌ Operation failed for %v (%v, %v, %v): %v\n", s.name, s.zone, s.machine, s.image, err)
			}
			results <- err
		}(s)
	}

	msgOut <- fmt.Sprintf("ℹī¸  Waiting for %v servers\n", len(specs))
	failed := 0
	for range specs {
		if err := <-results; err != nil {
			failed++
		}
	}

	if failed > 0 {
		return fmt.Errorf("%d operations failed", failed)
	}
	return nil
}

/**
 * instance wrapper
 */

// action carries the Scaleway Instance API handle; spawner, destroyer and
// runner embed it.
type action struct {
	api *instance.API
}

// init builds the API handle from the client returned by getClient. When
// getClient fails, i.api is left untouched and that error is returned.
func (i *action) init() error {
	c, err := getClient()
	if err == nil {
		i.api = instance.NewAPI(c)
	}
	return err
}
// getInstanceByName returns the server in zone whose name is exactly name.
// If none matches it returns instanceNotFound, unwrapped, because callers
// compare with ==.
//
// The name is passed to the API as a filter, but the results are still
// checked for an exact match because the filter may return other servers too.
// NOTE(review): the API appears to treat it as a partial match; confirm.
// scw.WithAllPages fetches every result page. Looking only at the first page
// could miss the exact match behind other partial matches, and the spawner
// would then create a duplicate server. If several servers share the exact
// name, the first one listed is returned.
func (i *action) getInstanceByName(zone scw.Zone, name string) (*instance.Server, error) {
	lr, err := i.api.ListServers(&instance.ListServersRequest{
		Zone: zone,
		Name: &name,
	}, scw.WithAllPages())
	if err != nil {
		return nil, err
	}
	for _, s := range lr.Servers {
		if s.Name == name {
			return s, nil
		}
	}
	return nil, instanceNotFound
}

// parseIP formats the address of s for display. It returns one of:
//   - the public address, with " (dynamic)" appended for a dynamic IP;
//   - otherwise the private address followed by " (private)";
//   - otherwise "(no address)".
func parseIP(s *instance.Server) string {
	switch {
	case s.PublicIP != nil:
		addr := s.PublicIP.Address.String()
		if !s.PublicIP.Dynamic {
			return addr
		}
		return addr + " (dynamic)"
	case s.PrivateIP != nil:
		return fmt.Sprintf("%s %s", *s.PrivateIP, "(private)")
	default:
		return "(no address)"
	}
}


/**
 * Spawner
 */
// spawner makes sure every listed instance exists and is running.
type spawner struct{ action }

// onInstance ensures a server called name exists in zone and is running.
//
// What it does depends on the server's current state:
//   - running: the server is only reported;
//   - exists in any other state: it is powered on;
//   - missing: it is created with the given commercial type (machine) and
//     image, with a dynamic IP requested, and then powered on.
//
// After a power-on the server is looked up again, so the reported address is
// the one it was given.
func (sp *spawner) onInstance(zone, machine, image, name string) error {
	z, err := scw.ParseZone(zone)
	if err != nil {
		return err
	}

	srv, err := sp.getInstanceByName(z, name)
	switch {
	case err == instanceNotFound:
		created, createErr := sp.api.CreateServer(&instance.CreateServerRequest{
			Zone:              z,
			Name:              name,
			CommercialType:    machine,
			Image:             image,
			DynamicIPRequired: scw.BoolPtr(true),
		})
		if createErr != nil {
			return createErr
		}
		srv = created.Server
	case err != nil:
		return err
	case srv.State == instance.ServerStateRunning:
		msgOut <- fmt.Sprintf("đŸŸŖ Found %v on zone %v with ip %v\n", srv.Name, srv.Zone, parseIP(srv))
		return nil
	}

	wait := 5 * time.Minute
	if err := sp.api.ServerActionAndWait(&instance.ServerActionAndWaitRequest{
		ServerID: srv.ID,
		Action:   instance.ServerActionPoweron,
		Zone:     z,
		Timeout:  &wait,
	}); err != nil {
		return err
	}

	if srv, err = sp.getInstanceByName(z, name); err != nil {
		return err
	}

	msgOut <- fmt.Sprintf("✅ Started %v on zone %v with ip %v\n", srv.Name, srv.Zone, parseIP(srv))
	return nil
}

/**
 * Destroyer
 */
// destroyer deletes every listed instance.
type destroyer struct{ action }

// onInstance deletes the server called name in zone. machine and image are
// part of the instanceReceiver signature and are unused here.
//
// A missing server counts as success. Otherwise the server is powered off
// (unless it is already stopped) and then deleted.
func (dt *destroyer) onInstance(zone, machine, image, name string) error {
	z, err := scw.ParseZone(zone)
	if err != nil {
		return err
	}

	targetServer, err := dt.getInstanceByName(z, name)
	if err == instanceNotFound {
		msgOut <- fmt.Sprintf("đŸŸŖ %v is already destroyed\n", name)
		return nil
	}
	if err != nil {
		return err
	}

	// The API rejects a poweroff request for a server that is already
	// stopped, which made destroy fail for every stopped server.
	// NOTE(review): transient states (starting/stopping) are still sent a
	// poweroff; confirm whether they need a wait first.
	if targetServer.State != instance.ServerStateStopped {
		err = dt.api.ServerActionAndWait(&instance.ServerActionAndWaitRequest{
			Zone:     z,
			ServerID: targetServer.ID,
			Action:   instance.ServerActionPoweroff,
		})
		if err != nil {
			return err
		}
	}

	// NOTE(review): DeleteServer is only given the server ID. Confirm whether
	// attached volumes survive it (and keep being billed); if so, delete them
	// explicitly.
	err = dt.api.DeleteServer(&instance.DeleteServerRequest{
		Zone:     z,
		ServerID: targetServer.ID,
	})
	if err != nil {
		return err
	}

	// targetServer is the snapshot taken before deletion, so its address is
	// still available for the report.
	msgOut <- fmt.Sprintf("✅ Destroyed %v on zone %v with ip %v\n", targetServer.Name, targetServer.Zone, parseIP(targetServer))
	return nil
}

/**
 * Runner
 */

// runner uploads a script to every listed instance and executes it there.
type runner struct{ action }

// connect opens an SSH connection as root to the public IP of the server
// called name in zone. It authenticates with the keys of the local ssh-agent
// found through SSH_AUTH_SOCK.
func (r *runner) connect(zone, name string) (*ssh.Client, error) {
	z, err := scw.ParseZone(zone)
	if err != nil {
		return nil, err
	}

	targetServer, err := r.getInstanceByName(z, name)
	if err != nil {
		return nil, err
	}

	if targetServer.PublicIP == nil {
		return nil, errors.New("run failed: this instance has no public ip.")
	}
	ip := targetServer.PublicIP.Address

	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		return nil, errors.New("SSH_AUTH_SOCK is not set: a running ssh-agent is required")
	}
	conn, err := net.Dial("unix", socket)
	if err != nil {
		return nil, err
	}
	// The agent is only consulted while ssh.Dial authenticates. Closing the
	// socket when connect returns prevents leaking one per instance.
	defer conn.Close()

	config := &ssh.ClientConfig{
		User: "root",
		Auth: []ssh.AuthMethod{
			ssh.PublicKeysCallback(agent.NewClient(conn).Signers),
		},
		// SECURITY NOTE(review): host keys are not verified, so the session is
		// open to man-in-the-middle attacks. Consider pinning or known_hosts.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		// Bound the TCP connect so an unreachable host cannot stall the run.
		Timeout: 30 * time.Second,
	}

	// JoinHostPort brackets IPv6 literals; plain "ip:22" concatenation does not.
	return ssh.Dial("tcp", net.JoinHostPort(ip.String(), "22"), config)
}

// send uploads the local script named by os.Args[2] to /tmp/nuage on the
// remote host over SFTP and makes it executable (mode 0755).
//
// The remote file's Close error is returned when nothing failed earlier.
// Close may be where a failed write surfaces. Dropping that error could leave
// a truncated script that exec would then run.
func (r *runner) send(sshClient *ssh.Client) (err error) {
	if len(os.Args) < 3 {
		return errors.New("missing script to run on the command line")
	}

	src, err := os.Open(os.Args[2])
	if err != nil {
		return err
	}
	defer src.Close()

	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		return err
	}
	defer sftpClient.Close()

	dst, err := sftpClient.Create("/tmp/nuage")
	if err != nil {
		return err
	}
	// Runs before sftpClient.Close (defers are LIFO), while the session is up.
	defer func() {
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if err = dst.Chmod(0755); err != nil {
		return err
	}

	_, err = io.Copy(dst, src)
	return err
}

// readerToChan forwards r to c one line at a time; every message sent ends in
// '\n'. It returns when r reports EOF or any other read error.
//
// ReadString returns a final unterminated line together with io.EOF, so the
// data is forwarded before the error is examined. Such a line gets a '\n'
// appended so the next message starts on its own line.
func readerToChan(r io.Reader, c chan string) {
	br := bufio.NewReader(r)
	for {
		line, err := br.ReadString('\n')
		if line != "" {
			if line[len(line)-1] != '\n' {
				line += "\n"
			}
			c <- line
		}
		if err != nil {
			return
		}
	}
}
// exec runs the previously uploaded /tmp/nuage in a new SSH session on
// sshClient and returns the result of session.Run.
//
// When VERBOSE is set, the remote stderr and stdout are forwarded line by line
// to msgErr and msgOut. They are wired in as session writers through io.Pipe,
// so Run copies the streams while the command executes and a goroutine
// splits them into lines. Draining the pipes sequentially *before* Run blocks
// forever, because nothing has been started yet to produce output.
func (r *runner) exec(sshClient *ssh.Client) error {
	session, err := sshClient.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()

	if os.Getenv("VERBOSE") != "" {
		errR, errW := io.Pipe()
		outR, outW := io.Pipe()
		session.Stderr = errW
		session.Stdout = outW

		forwarded := make(chan struct{}, 2)
		go func() {
			readerToChan(errR, msgErr)
			forwarded <- struct{}{}
		}()
		go func() {
			readerToChan(outR, msgOut)
			forwarded <- struct{}{}
		}()

		// Once Run has returned, no more output will be written. Closing the
		// writers lets the forwarders hit EOF; waiting for them ensures the tail
		// of the output is queued before exec reports the result.
		defer func() {
			errW.Close()
			outW.Close()
			<-forwarded
			<-forwarded
		}()
	}

	return session.Run("/tmp/nuage")
}

// onInstance connects to the server called name in zone, uploads the script
// given on the command line and runs it. machine and image are unused here.
func (r *runner) onInstance(zone, machine, image, name string) error {
	client, err := r.connect(zone, name)
	if err != nil {
		return err
	}
	defer client.Close()

	if err := r.send(client); err != nil {
		return err
	}
	if err := r.exec(client); err != nil {
		return err
	}

	msgOut <- fmt.Sprintf("✅ Successfully ran the script on %v (zone %v)\n", name, zone)
	return nil
}


/**
 * Commands
 */

func spawn() error {
  sp := spawner{}
  err := sp.init()
  if err != nil {
    return err
  }

  err = passInstanceTo(os.Stdin, &sp)
  return err
}

func run() error {
  r := runner{}
  err := r.init()
  if err != nil {
    return err
  }

  err = passInstanceTo(os.Stdin, &r)
  return err
}

func destroy() error {
  dt := destroyer{}
  err := dt.init()
  if err != nil {
    return err
  }

  err = passInstanceTo(os.Stdin, &dt)
  return err
}


// getClient builds a Scaleway API client from the *active* profile of the
// local scw configuration file. Environment variables (scw.WithEnv) can
// override the profile's values.
//
// The configuration file is created by the scw CLI
// (https://github.com/scaleway/scaleway-cli): install it and run `scw init`.
func getClient() (*scw.Client, error) {
	cfg, err := scw.LoadConfig()
	if err != nil {
		return nil, err
	}

	active, err := cfg.GetActiveProfile()
	if err != nil {
		return nil, err
	}

	return scw.NewClient(
		scw.WithProfile(active),
		scw.WithEnv(), // applied after the profile, so environment values win
	)
}