forked from Deuxfleurs/bottin
271 lines
6.8 KiB
Go
271 lines
6.8 KiB
Go
package ldapserver
|
|
|
|
import (
|
|
"bufio"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
ldap "github.com/vjeantet/goldap/message"
|
|
)
|
|
|
|
// UserState is an opaque, application-defined value attached to a client.
// This file only stores it (client.userState) and hands it unchanged to
// Handler.ServeLDAP for every request; it never inspects its contents.
type UserState interface{}
|
|
|
|
type client struct {
|
|
Numero int
|
|
srv *Server
|
|
rwc net.Conn
|
|
br *bufio.Reader
|
|
bw *bufio.Writer
|
|
chanOut chan *ldap.LDAPMessage
|
|
wg sync.WaitGroup
|
|
closing chan bool
|
|
requestList map[int]*Message
|
|
mutex sync.Mutex
|
|
writeDone chan bool
|
|
rawData []byte
|
|
userState UserState
|
|
}
|
|
|
|
func (c *client) GetConn() net.Conn {
|
|
return c.rwc
|
|
}
|
|
|
|
func (c *client) GetRaw() []byte {
|
|
return c.rawData
|
|
}
|
|
|
|
func (c *client) SetConn(conn net.Conn) {
|
|
c.rwc = conn
|
|
c.br = bufio.NewReader(c.rwc)
|
|
c.bw = bufio.NewWriter(c.rwc)
|
|
}
|
|
|
|
func (c *client) GetMessageByID(messageID int) (*Message, bool) {
|
|
if requestToAbandon, ok := c.requestList[messageID]; ok {
|
|
return requestToAbandon, true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func (c *client) Addr() net.Addr {
|
|
return c.rwc.RemoteAddr()
|
|
}
|
|
|
|
func (c *client) ReadPacket() (*messagePacket, error) {
|
|
mP, err := readMessagePacket(c.br)
|
|
c.rawData = make([]byte, len(mP.bytes))
|
|
copy(c.rawData, mP.bytes)
|
|
return mP, err
|
|
}
|
|
|
|
func (c *client) serve() {
|
|
defer c.close()
|
|
|
|
c.closing = make(chan bool)
|
|
if onc := c.srv.OnNewConnection; onc != nil {
|
|
if err := onc(c.rwc); err != nil {
|
|
Logger.Printf("Erreur OnNewConnection: %s", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Create the ldap response queue to be writted to client (buffered to 20)
|
|
// buffered to 20 means that If client is slow to handler responses, Server
|
|
// Handlers will stop to send more respones
|
|
c.chanOut = make(chan *ldap.LDAPMessage)
|
|
c.writeDone = make(chan bool)
|
|
// for each message in c.chanOut send it to client
|
|
go func() {
|
|
for msg := range c.chanOut {
|
|
c.writeMessage(msg)
|
|
}
|
|
close(c.writeDone)
|
|
}()
|
|
|
|
// Listen for server signal to shutdown
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-c.srv.chDone: // server signals shutdown process
|
|
c.wg.Add(1)
|
|
r := NewExtendedResponse(LDAPResultUnwillingToPerform)
|
|
r.SetDiagnosticMessage("server is about to stop")
|
|
r.SetResponseName(NoticeOfDisconnection)
|
|
|
|
m := ldap.NewLDAPMessageWithProtocolOp(r)
|
|
|
|
c.chanOut <- m
|
|
c.wg.Done()
|
|
c.rwc.SetReadDeadline(time.Now().Add(time.Millisecond))
|
|
return
|
|
case <-c.closing:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
c.requestList = make(map[int]*Message)
|
|
|
|
for {
|
|
|
|
if c.srv.ReadTimeout != 0 {
|
|
c.rwc.SetReadDeadline(time.Now().Add(c.srv.ReadTimeout))
|
|
}
|
|
if c.srv.WriteTimeout != 0 {
|
|
c.rwc.SetWriteDeadline(time.Now().Add(c.srv.WriteTimeout))
|
|
}
|
|
|
|
//Read client input as a ASN1/BER binary message
|
|
messagePacket, err := c.ReadPacket()
|
|
if err != nil {
|
|
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
|
|
Logger.Printf("Sorry client %d, i can not wait anymore (reading timeout) ! %s", c.Numero, err)
|
|
} else {
|
|
Logger.Printf("Error readMessagePacket: %s", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
//Convert ASN1 binaryMessage to a ldap Message
|
|
message, err := messagePacket.readMessage()
|
|
|
|
if err != nil {
|
|
Logger.Printf("Error reading Message : %s\n\t%x", err.Error(), messagePacket.bytes)
|
|
continue
|
|
}
|
|
if TRACE {
|
|
//Logger.Printf("<<< %d - %s - hex=%x", c.Numero, message.ProtocolOpName(), messagePacket)
|
|
Logger.Printf("<<< %d - %#v", c.Numero, message)
|
|
}
|
|
|
|
// TODO: Use a implementation to limit runnuning request by client
|
|
// solution 1 : when the buffered output channel is full, send a busy
|
|
// solution 2 : when 10 client requests (goroutines) are running, send a busy message
|
|
// And when the limit is reached THEN send a BusyLdapMessage
|
|
|
|
// When message is an UnbindRequest, stop serving
|
|
if _, ok := message.ProtocolOp().(ldap.UnbindRequest); ok {
|
|
return
|
|
}
|
|
|
|
// If client requests a startTls, do not handle it in a
|
|
// goroutine, connection has to remain free until TLS is OK
|
|
// @see RFC https://tools.ietf.org/html/rfc4511#section-4.14.1
|
|
if req, ok := message.ProtocolOp().(ldap.ExtendedRequest); ok {
|
|
if req.RequestName() == NoticeOfStartTLS {
|
|
c.wg.Add(1)
|
|
c.ProcessRequestMessage(&message)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// TODO: go/non go routine choice should be done in the ProcessRequestMessage
|
|
// not in the client.serve func
|
|
c.wg.Add(1)
|
|
go c.ProcessRequestMessage(&message)
|
|
}
|
|
|
|
}
|
|
|
|
// close closes client,
|
|
// * stop reading from client
|
|
// * signals to all currently running request processor to stop
|
|
// * wait for all request processor to end
|
|
// * close client connection
|
|
// * signal to server that client shutdown is ok
|
|
func (c *client) close() {
|
|
if DEBUG {
|
|
Logger.Printf("client %d close()", c.Numero)
|
|
}
|
|
close(c.closing)
|
|
|
|
// stop reading from client
|
|
c.rwc.SetReadDeadline(time.Now().Add(time.Millisecond))
|
|
if DEBUG {
|
|
Logger.Printf("client %d close() - stop reading from client", c.Numero)
|
|
}
|
|
|
|
// signals to all currently running request processor to stop
|
|
c.mutex.Lock()
|
|
for messageID, request := range c.requestList {
|
|
Logger.Printf("Client %d close() - sent abandon signal to request[messageID = %d]", c.Numero, messageID)
|
|
go request.Abandon()
|
|
}
|
|
c.mutex.Unlock()
|
|
if DEBUG {
|
|
Logger.Printf("client %d close() - Abandon signal sent to processors", c.Numero)
|
|
}
|
|
|
|
c.wg.Wait() // wait for all current running request processor to end
|
|
close(c.chanOut) // No more message will be sent to client, close chanOUT
|
|
if DEBUG {
|
|
Logger.Printf("client [%d] request processors ended", c.Numero)
|
|
}
|
|
|
|
<-c.writeDone // Wait for the last message sent to be written
|
|
c.rwc.Close() // close client connection
|
|
Logger.Printf("client [%d] connection closed", c.Numero)
|
|
|
|
c.srv.wg.Done() // signal to server that client shutdown is ok
|
|
}
|
|
|
|
func (c *client) writeMessage(m *ldap.LDAPMessage) {
|
|
data, _ := m.Write()
|
|
if TRACE {
|
|
//Logger.Printf(">>> %d - %s - hex=%x", c.Numero, m.ProtocolOpName(), data.Bytes())
|
|
Logger.Printf(">>> %d - %#v", c.Numero, m)
|
|
}
|
|
c.bw.Write(data.Bytes())
|
|
c.bw.Flush()
|
|
}
|
|
|
|
// ResponseWriter interface is used by an LDAP handler to
|
|
// construct an LDAP response.
|
|
type ResponseWriter interface {
|
|
// Write writes the LDAPResponse to the connection as part of an LDAP reply.
|
|
Write(po ldap.ProtocolOp)
|
|
}
|
|
|
|
type responseWriterImpl struct {
|
|
chanOut chan *ldap.LDAPMessage
|
|
messageID int
|
|
}
|
|
|
|
func (w responseWriterImpl) Write(po ldap.ProtocolOp) {
|
|
m := ldap.NewLDAPMessageWithProtocolOp(po)
|
|
m.SetMessageID(w.messageID)
|
|
w.chanOut <- m
|
|
}
|
|
|
|
func (c *client) ProcessRequestMessage(message *ldap.LDAPMessage) {
|
|
defer c.wg.Done()
|
|
|
|
var m Message
|
|
m = Message{
|
|
LDAPMessage: message,
|
|
Done: make(chan bool, 2),
|
|
Client: c,
|
|
}
|
|
|
|
c.registerRequest(&m)
|
|
defer c.unregisterRequest(&m)
|
|
|
|
var w responseWriterImpl
|
|
w.chanOut = c.chanOut
|
|
w.messageID = m.MessageID().Int()
|
|
|
|
c.srv.Handler.ServeLDAP(c.userState, w, &m)
|
|
}
|
|
|
|
func (c *client) registerRequest(m *Message) {
|
|
c.mutex.Lock()
|
|
c.requestList[m.MessageID().Int()] = m
|
|
c.mutex.Unlock()
|
|
}
|
|
|
|
func (c *client) unregisterRequest(m *Message) {
|
|
c.mutex.Lock()
|
|
delete(c.requestList, m.MessageID().Int())
|
|
c.mutex.Unlock()
|
|
}
|