Handle shutdown better

This commit is contained in:
Alex 2020-03-02 21:51:13 +01:00
parent 7e0dbc5824
commit df1f2d38b3
5 changed files with 69 additions and 29 deletions

View file

@ -130,6 +130,7 @@ func RemoveAccount(mxUser string, name string) {
func CloseAllAcountsForShutdown() { func CloseAllAcountsForShutdown() {
accountsLock.Lock() accountsLock.Lock()
defer accountsLock.Unlock()
for _, accl := range registeredAccounts { for _, accl := range registeredAccounts {
for _, acct := range accl { for _, acct := range accl {
acct.Conn.Close() acct.Conn.Close()

View file

@ -79,8 +79,8 @@ type External struct {
config Configuration config Configuration
recv io.Reader recvPipe io.ReadCloser
send io.Writer sendPipe io.WriteCloser
sendJson *json.Encoder sendJson *json.Encoder
generation int generation int
@ -114,7 +114,7 @@ func (ext *External) Configure(c Configuration) error {
ext.handlerChan = make(chan *extMessageWithData, 1000) ext.handlerChan = make(chan *extMessageWithData, 1000)
go ext.handlerLoop(ext.generation) go ext.handlerLoop(ext.generation)
err = ext.setupProc() err = ext.setupProc(ext.generation)
if err != nil { if err != nil {
return err return err
} }
@ -133,26 +133,28 @@ func (ext *External) Configure(c Configuration) error {
// ---- Process management and communication logic // ---- Process management and communication logic
func (ext *External) setupProc() error { func (ext *External) setupProc(generation int) error {
var err error var err error
ext.proc = exec.Command(ext.command) ext.proc = exec.Command(ext.command)
ext.recv, err = ext.proc.StdoutPipe() ext.recvPipe, err = ext.proc.StdoutPipe()
if err != nil { if err != nil {
return err return err
} }
ext.send, err = ext.proc.StdinPipe() ext.sendPipe, err = ext.proc.StdinPipe()
if err != nil { if err != nil {
return err return err
} }
send := io.Writer(ext.sendPipe)
recv := io.Reader(ext.recvPipe)
if ext.debug { if ext.debug {
ext.recv = io.TeeReader(ext.recv, os.Stderr) recv = io.TeeReader(recv, os.Stderr)
ext.send = io.MultiWriter(ext.send, os.Stderr) send = io.MultiWriter(send, os.Stderr)
} }
ext.sendJson = json.NewEncoder(ext.send) ext.sendJson = json.NewEncoder(send)
ext.proc.Stderr = os.Stderr ext.proc.Stderr = os.Stderr
@ -161,7 +163,7 @@ func (ext *External) setupProc() error {
return err return err
} }
go ext.recvLoop() go ext.recvLoop(recv, generation)
return nil return nil
} }
@ -175,7 +177,7 @@ func (ext *External) restartLoop(generation int) {
break break
} }
log.Printf("Process %s stopped, restarting.", ext.command) log.Printf("Process %s stopped, restarting.", ext.command)
err := ext.setupProc() err := ext.setupProc(generation)
if err != nil { if err != nil {
ext.proc = nil ext.proc = nil
log.Warnf("Unable to restart %s: %s", ext.command, err) log.Warnf("Unable to restart %s: %s", ext.command, err)
@ -232,8 +234,8 @@ func (m *extMessageWithData) UnmarshalJSON(jj []byte) error {
} }
func (ext *External) recvLoop() { func (ext *External) recvLoop(from io.Reader, generation int) {
scanner := bufio.NewScanner(ext.recv) scanner := bufio.NewScanner(from)
for scanner.Scan() { for scanner.Scan() {
var msg extMessageWithData var msg extMessageWithData
err := json.Unmarshal(scanner.Bytes(), &msg) err := json.Unmarshal(scanner.Bytes(), &msg)
@ -260,6 +262,10 @@ func (ext *External) recvLoop() {
} else { } else {
ext.handlerChan <- &msg ext.handlerChan <- &msg
} }
if ext.generation != generation {
break
}
} }
} }
@ -315,20 +321,25 @@ func (ext *External) cmd(msg extMessage, data interface{}) (*extMessageWithData,
} }
func (ext *External) Close() { func (ext *External) Close() {
ext.generation += 1
ext.sendJson.Encode(&extMessage{ ext.sendJson.Encode(&extMessage{
MsgType: CLOSE, MsgType: CLOSE,
}) })
ext.generation += 1
proc := ext.proc proc := ext.proc
proc.Process.Signal(os.Interrupt)
ext.recvPipe.Close()
ext.sendPipe.Close()
ext.proc = nil ext.proc = nil
ext.recv = nil ext.recvPipe = nil
ext.send = nil ext.sendPipe = nil
ext.sendJson = nil ext.sendJson = nil
ext.handlerChan = nil ext.handlerChan = nil
go func() { go func() {
time.Sleep(10 * time.Second) time.Sleep(1 * time.Second)
proc.Process.Kill() proc.Process.Kill()
}() }()
} }

22
main.go
View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -9,6 +10,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -179,12 +181,12 @@ func main() {
sigch := make(chan os.Signal) sigch := make(chan os.Signal)
signal.Notify(sigch, os.Interrupt, syscall.SIGTERM) signal.Notify(sigch, os.Interrupt, syscall.SIGTERM)
err = StartAppService(errch) as_server, err := StartAppService(errch)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
StartWeb(errch) web_server := StartWeb(errch)
// Wait for an error somewhere or interrupt signal // Wait for an error somewhere or interrupt signal
select { select {
@ -196,6 +198,22 @@ func main() {
log.Warnf("Got signal %s", sig.String()) log.Warnf("Got signal %s", sig.String())
} }
// Shut down, hopefully this is not a too bad way to do it
log.Warn("Shuttind down") log.Warn("Shuttind down")
delay := 2 * time.Second
ctx1, _ := context.WithTimeout(context.TODO(), delay)
go as_server.Shutdown(ctx1)
ctx2, _ := context.WithTimeout(context.TODO(), delay)
go web_server.Shutdown(ctx2)
time.Sleep(delay)
CloseAllAcountsForShutdown() CloseAllAcountsForShutdown()
if err != nil {
os.Exit(1)
} else {
os.Exit(0)
}
} }

View file

@ -16,18 +16,18 @@ import (
var mx *mxlib.Client var mx *mxlib.Client
func StartAppService(errch chan error) error { func StartAppService(errch chan error) (*http.Server, error) {
mx = mxlib.NewClient(config.Server, registration.AsToken) mx = mxlib.NewClient(config.Server, registration.AsToken)
err := InitDb() err := InitDb()
if err != nil { if err != nil {
return err return nil, err
} }
if dbKvGet("ezbr_initialized") != "yes" { if dbKvGet("ezbr_initialized") != "yes" {
err = mx.RegisterUser(registration.SenderLocalpart) err = mx.RegisterUser(registration.SenderLocalpart)
if mxe, ok := err.(*mxlib.MxError); !ok || mxe.ErrCode != "M_USER_IN_USE" { if mxe, ok := err.(*mxlib.MxError); !ok || mxe.ErrCode != "M_USER_IN_USE" {
return err return nil, err
} }
_, st := os.Stat(config.AvatarFile) _, st := os.Stat(config.AvatarFile)
@ -36,13 +36,13 @@ func StartAppService(errch chan error) error {
Path: config.AvatarFile, Path: config.AvatarFile,
}) })
if err != nil { if err != nil {
return err return nil, err
} }
} }
err = mx.ProfileDisplayname(ezbrMxId(), fmt.Sprintf("Easybridge (%s)", EASYBRIDGE_SYSTEM_PROTOCOL)) err = mx.ProfileDisplayname(ezbrMxId(), fmt.Sprintf("Easybridge (%s)", EASYBRIDGE_SYSTEM_PROTOCOL))
if err != nil { if err != nil {
return err return nil, err
} }
dbKvPut("ezbr_initialized", "yes") dbKvPut("ezbr_initialized", "yes")
@ -52,9 +52,13 @@ func StartAppService(errch chan error) error {
router.HandleFunc("/_matrix/app/v1/transactions/{txnId}", handleTxn) router.HandleFunc("/_matrix/app/v1/transactions/{txnId}", handleTxn)
router.HandleFunc("/transactions/{txnId}", handleTxn) router.HandleFunc("/transactions/{txnId}", handleTxn)
go func() {
log.Printf("Starting HTTP server on %s", config.ASBindAddr) log.Printf("Starting HTTP server on %s", config.ASBindAddr)
err := http.ListenAndServe(config.ASBindAddr, checkTokenAndLog(router)) http_server := &http.Server{
Addr: config.ASBindAddr,
Handler: checkTokenAndLog(router),
}
go func() {
err := http_server.ListenAndServe()
if err != nil { if err != nil {
errch <- err errch <- err
} }
@ -71,7 +75,7 @@ func StartAppService(errch chan error) error {
} }
}() }()
return nil return http_server, nil
} }
func checkTokenAndLog(handler http.Handler) http.Handler { func checkTokenAndLog(handler http.Handler) http.Handler {

10
web.go
View file

@ -21,7 +21,7 @@ const SESSION_NAME = "easybridge_session"
var sessionsStore sessions.Store = nil var sessionsStore sessions.Store = nil
var userKeys = map[string]*[32]byte{} var userKeys = map[string]*[32]byte{}
func StartWeb(errch chan error) { func StartWeb(errch chan error) *http.Server {
session_key := blake2b.Sum256([]byte(config.SessionKey)) session_key := blake2b.Sum256([]byte(config.SessionKey))
sessionsStore = sessions.NewCookieStore(session_key[:]) sessionsStore = sessions.NewCookieStore(session_key[:])
@ -36,12 +36,18 @@ func StartWeb(errch chan error) {
r.Handle("/static/{file:.*}", http.StripPrefix("/static/", staticfiles)) r.Handle("/static/{file:.*}", http.StripPrefix("/static/", staticfiles))
log.Printf("Starting web UI HTTP server on %s", config.WebBindAddr) log.Printf("Starting web UI HTTP server on %s", config.WebBindAddr)
web_server := &http.Server{
Addr: config.WebBindAddr,
Handler: logRequest(r),
}
go func() { go func() {
err := http.ListenAndServe(config.WebBindAddr, logRequest(r)) err := web_server.ListenAndServe()
if err != nil { if err != nil {
errch <- err errch <- err
} }
}() }()
return web_server
} }
func logRequest(handler http.Handler) http.Handler { func logRequest(handler http.Handler) http.Handler {