diff --git a/handlers.go b/handlers.go index 869e1c1..4e127b4 100644 --- a/handlers.go +++ b/handlers.go @@ -81,14 +81,14 @@ func handleLogin(ectx echo.Context) error { username := ctx.FormValue("username") password := ctx.FormValue("password") if username != "" && password != "" { - token, err := ctx.server.sessions.Put(username, password) + s, err := ctx.server.sessions.Put(username, password) if err != nil { if _, ok := err.(AuthError); ok { return ctx.Render(http.StatusOK, "login.html", nil) } return fmt.Errorf("failed to put connection in pool: %v", err) } - ctx.setToken(token) + ctx.setToken(s.Token) return ctx.Redirect(http.StatusFound, "/mailbox/INBOX") } @@ -99,13 +99,7 @@ func handleLogin(ectx echo.Context) error { func handleLogout(ectx echo.Context) error { ctx := ectx.(*context) - err := ctx.session.Do(func(c *imapclient.Client) error { - return c.Logout() - }) - if err != nil { - return fmt.Errorf("failed to logout: %v", err) - } - + ctx.session.Close() ctx.setToken("") return ctx.Redirect(http.StatusFound, "/login") } diff --git a/server.go b/server.go index 94cc133..0ab7f24 100644 --- a/server.go +++ b/server.go @@ -173,6 +173,7 @@ func New(e *echo.Echo, options *Options) error { } else if err != nil { return err } + ctx.session.Ping() return next(ctx) } diff --git a/session.go b/session.go index 0292d3a..8c8eecf 100644 --- a/session.go +++ b/session.go @@ -6,10 +6,14 @@ import ( "errors" "fmt" "sync" + "time" imapclient "github.com/emersion/go-imap/client" ) +// TODO: make this configurable +const sessionDuration = 30 * time.Minute + func generateToken() (string, error) { b := make([]byte, 32) _, err := rand.Read(b) @@ -30,23 +34,52 @@ func (err AuthError) Error() string { } type Session struct { - locker sync.Mutex - imapConn *imapclient.Client + Token string + + manager *SessionManager username, password string + closed chan struct{} + pings chan struct{} + timer *time.Timer + + locker sync.Mutex + imapConn *imapclient.Client // protected by locker, can be nil +} + +func (s *Session) Ping() { + s.pings <- struct{}{} } func (s *Session) Do(f func(*imapclient.Client) error) error { s.locker.Lock() defer s.locker.Unlock() + if s.imapConn == nil { + var err error + s.imapConn, err = s.manager.connect(s.username, s.password) + if err != nil { + s.Close() + return fmt.Errorf("failed to re-connect to IMAP server: %v", err) + } + } + return f(s.imapConn) } -// TODO: expiration timer +func (s *Session) Close() { + select { + case <-s.closed: + // This space is intentionally left blank + default: + close(s.closed) + } +} + type SessionManager struct { - locker sync.Mutex - sessions map[string]*Session newIMAPClient func() (*imapclient.Client, error) + + locker sync.Mutex + sessions map[string]*Session // protected by locker } func NewSessionManager(newIMAPClient func() (*imapclient.Client, error)) *SessionManager { @@ -81,21 +114,22 @@ func (sm *SessionManager) Get(token string) (*Session, error) { return session, nil } -func (sm *SessionManager) Put(username, password string) (token string, err error) { +func (sm *SessionManager) Put(username, password string) (*Session, error) { c, err := sm.connect(username, password) if err != nil { - return "", err + return nil, err } sm.locker.Lock() defer sm.locker.Unlock() + var token string for { var err error token, err = generateToken() if err != nil { c.Logout() - return "", err + return nil, err } if _, ok := sm.sessions[token]; !ok { @@ -103,19 +137,60 @@ func (sm *SessionManager) Put(username, password string) (token string, err erro } } - sm.sessions[token] = &Session{ + s := &Session{ + Token: token, + manager: sm, + closed: make(chan struct{}), + pings: make(chan struct{}, 5), imapConn: c, username: username, password: password, } + sm.sessions[token] = s go func() { - <-c.LoggedOut() + timer := time.NewTimer(sessionDuration) + + alive := true + for alive { + var loggedOut <-chan struct{} + s.locker.Lock() + if s.imapConn != nil { + loggedOut = s.imapConn.LoggedOut() + } + s.locker.Unlock() + + select { + case <-loggedOut: + s.locker.Lock() + s.imapConn = nil + s.locker.Unlock() + case <-s.pings: + if !timer.Stop() { + <-timer.C + } + timer.Reset(sessionDuration) + case <-timer.C: + alive = false + case <-s.closed: + alive = false + } + } + + if !timer.Stop() { + <-timer.C + } + + s.locker.Lock() + if s.imapConn != nil { + s.imapConn.Logout() + } + s.locker.Unlock() sm.locker.Lock() delete(sm.sessions, token) sm.locker.Unlock() }() - return token, nil + return s, nil }