diff --git a/account.go b/account.go index b4e7118..cdcccfe 100644 --- a/account.go +++ b/account.go @@ -130,6 +130,7 @@ func RemoveAccount(mxUser string, name string) { func CloseAllAcountsForShutdown() { accountsLock.Lock() + defer accountsLock.Unlock() for _, accl := range registeredAccounts { for _, acct := range accl { acct.Conn.Close() diff --git a/connector/external/external.go b/connector/external/external.go index 6117880..b66ab6c 100644 --- a/connector/external/external.go +++ b/connector/external/external.go @@ -79,8 +79,8 @@ type External struct { config Configuration - recv io.Reader - send io.Writer + recvPipe io.ReadCloser + sendPipe io.WriteCloser sendJson *json.Encoder generation int @@ -114,7 +114,7 @@ func (ext *External) Configure(c Configuration) error { ext.handlerChan = make(chan *extMessageWithData, 1000) go ext.handlerLoop(ext.generation) - err = ext.setupProc() + err = ext.setupProc(ext.generation) if err != nil { return err } @@ -133,26 +133,28 @@ func (ext *External) Configure(c Configuration) error { // ---- Process management and communication logic -func (ext *External) setupProc() error { +func (ext *External) setupProc(generation int) error { var err error ext.proc = exec.Command(ext.command) - ext.recv, err = ext.proc.StdoutPipe() + ext.recvPipe, err = ext.proc.StdoutPipe() if err != nil { return err } - ext.send, err = ext.proc.StdinPipe() + ext.sendPipe, err = ext.proc.StdinPipe() if err != nil { return err } + send := io.Writer(ext.sendPipe) + recv := io.Reader(ext.recvPipe) if ext.debug { - ext.recv = io.TeeReader(ext.recv, os.Stderr) - ext.send = io.MultiWriter(ext.send, os.Stderr) + recv = io.TeeReader(recv, os.Stderr) + send = io.MultiWriter(send, os.Stderr) } - ext.sendJson = json.NewEncoder(ext.send) + ext.sendJson = json.NewEncoder(send) ext.proc.Stderr = os.Stderr @@ -161,7 +163,7 @@ func (ext *External) setupProc() error { return err } - go ext.recvLoop() + go ext.recvLoop(recv, generation) return nil } @@ -175,7 +177,7 @@ func (ext *External) restartLoop(generation int) { break } log.Printf("Process %s stopped, restarting.", ext.command) - err := ext.setupProc() + err := ext.setupProc(generation) if err != nil { ext.proc = nil log.Warnf("Unable to restart %s: %s", ext.command, err) @@ -232,8 +234,8 @@ func (m *extMessageWithData) UnmarshalJSON(jj []byte) error { } -func (ext *External) recvLoop() { - scanner := bufio.NewScanner(ext.recv) +func (ext *External) recvLoop(from io.Reader, generation int) { + scanner := bufio.NewScanner(from) for scanner.Scan() { var msg extMessageWithData err := json.Unmarshal(scanner.Bytes(), &msg) @@ -260,6 +262,10 @@ func (ext *External) recvLoop() { } else { ext.handlerChan <- &msg } + + if ext.generation != generation { + break + } } } @@ -315,20 +321,25 @@ func (ext *External) cmd(msg extMessage, data interface{}) (*extMessageWithData, } func (ext *External) Close() { + ext.generation += 1 + ext.sendJson.Encode(&extMessage{ MsgType: CLOSE, }) - ext.generation += 1 proc := ext.proc + proc.Process.Signal(os.Interrupt) + ext.recvPipe.Close() + ext.sendPipe.Close() + ext.proc = nil - ext.recv = nil - ext.send = nil + ext.recvPipe = nil + ext.sendPipe = nil ext.sendJson = nil ext.handlerChan = nil go func() { - time.Sleep(10 * time.Second) + time.Sleep(1 * time.Second) proc.Process.Kill() }() } diff --git a/main.go b/main.go index d6597a2..a5b4394 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/rand" "encoding/hex" "encoding/json" @@ -9,6 +10,7 @@ import ( "os" "os/signal" "syscall" + "time" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -179,12 +181,12 @@ func main() { sigch := make(chan os.Signal) signal.Notify(sigch, os.Interrupt, syscall.SIGTERM) - err = StartAppService(errch) + as_server, err := StartAppService(errch) if err != nil { log.Fatal(err) } - StartWeb(errch) + web_server := StartWeb(errch) // Wait for an error somewhere or interrupt signal select { @@ -196,6 +198,22 @@ func main() { log.Warnf("Got signal %s", sig.String()) } + // Shut down, hopefully this is not a too bad way to do it log.Warn("Shuttind down") + delay := 2 * time.Second + + ctx1, _ := context.WithTimeout(context.TODO(), delay) + go as_server.Shutdown(ctx1) + + ctx2, _ := context.WithTimeout(context.TODO(), delay) + go web_server.Shutdown(ctx2) + + time.Sleep(delay) CloseAllAcountsForShutdown() + + if err != nil { + os.Exit(1) + } else { + os.Exit(0) + } } diff --git a/server.go b/server.go index 65f4f78..ce0334a 100644 --- a/server.go +++ b/server.go @@ -16,18 +16,18 @@ import ( var mx *mxlib.Client -func StartAppService(errch chan error) error { +func StartAppService(errch chan error) (*http.Server, error) { mx = mxlib.NewClient(config.Server, registration.AsToken) err := InitDb() if err != nil { - return err + return nil, err } if dbKvGet("ezbr_initialized") != "yes" { err = mx.RegisterUser(registration.SenderLocalpart) if mxe, ok := err.(*mxlib.MxError); !ok || mxe.ErrCode != "M_USER_IN_USE" { - return err + return nil, err } _, st := os.Stat(config.AvatarFile) @@ -36,13 +36,13 @@ func StartAppService(errch chan error) error { Path: config.AvatarFile, }) if err != nil { - return err + return nil, err } } err = mx.ProfileDisplayname(ezbrMxId(), fmt.Sprintf("Easybridge (%s)", EASYBRIDGE_SYSTEM_PROTOCOL)) if err != nil { - return err + return nil, err } dbKvPut("ezbr_initialized", "yes") @@ -52,9 +52,13 @@ func StartAppService(errch chan error) error { router.HandleFunc("/_matrix/app/v1/transactions/{txnId}", handleTxn) router.HandleFunc("/transactions/{txnId}", handleTxn) + log.Printf("Starting HTTP server on %s", config.ASBindAddr) + http_server := &http.Server{ + Addr: config.ASBindAddr, + Handler: checkTokenAndLog(router), + } go func() { - log.Printf("Starting HTTP server on %s", config.ASBindAddr) - err := http.ListenAndServe(config.ASBindAddr, checkTokenAndLog(router)) + err := http_server.ListenAndServe() if err != nil { errch <- err } @@ -71,7 +75,7 @@ func StartAppService(errch chan error) error { } }() - return nil + return http_server, nil } func checkTokenAndLog(handler http.Handler) http.Handler { diff --git a/web.go b/web.go index d5846c5..2a04a6c 100644 --- a/web.go +++ b/web.go @@ -21,7 +21,7 @@ const SESSION_NAME = "easybridge_session" var sessionsStore sessions.Store = nil var userKeys = map[string]*[32]byte{} -func StartWeb(errch chan error) { +func StartWeb(errch chan error) *http.Server { session_key := blake2b.Sum256([]byte(config.SessionKey)) sessionsStore = sessions.NewCookieStore(session_key[:]) @@ -36,12 +36,18 @@ func StartWeb(errch chan error) { r.Handle("/static/{file:.*}", http.StripPrefix("/static/", staticfiles)) log.Printf("Starting web UI HTTP server on %s", config.WebBindAddr) + web_server := &http.Server{ + Addr: config.WebBindAddr, + Handler: logRequest(r), + } go func() { - err := http.ListenAndServe(config.WebBindAddr, logRequest(r)) + err := web_server.ListenAndServe() if err != nil { errch <- err } }() + + return web_server } func logRequest(handler http.Handler) http.Handler {