package main

import (
	"fmt"
	"strings"

	ldap "./ldapserver"
	message "github.com/vjeantet/goldap/message"
)

// Compare request -------------------------

// handleCompare serves an LDAP Compare operation. It delegates the work to
// handleCompareInternal and writes a CompareResponse carrying the resulting
// code; when an error is returned its text becomes the diagnostic message.
func (server *Server) handleCompare(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetCompareRequest()

	code, err := server.handleCompareInternal(state, &r)

	res := ldap.NewResponse(code)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
	}
	w.Write(message.CompareResponse(res))
}

// handleCompareInternal checks whether the entry named by r.Entry() has the
// attribute r.Ava().AttributeDesc() holding exactly r.Ava().AssertionValue().
//
// It returns LDAPResultCompareTrue / LDAPResultCompareFalse on success, or:
//   - LDAPResultInvalidDNSyntax if checkSuffix rejects the DN,
//   - LDAPResultInsufficientAccessRights if the user may not read the attribute,
//   - LDAPResultNoSuchObject if the entry does not exist,
//   - LDAPResultOperationsError on storage errors.
//
// Values are compared with ==, i.e. an exact, case-sensitive string match.
func (server *Server) handleCompareInternal(state *State, r *message.CompareRequest) (int, error) {
	dn := string(r.Entry())
	attr := string(r.Ava().AttributeDesc())
	expected := string(r.Ava().AssertionValue())

	if _, err := server.checkSuffix(dn, false); err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	// Check permissions. Arguments follow the same order as the calls in the
	// search handler below: login, action, target DN, attributes.
	if !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
		return ldap.LDAPResultInsufficientAccessRights, nil
	}

	// Do query
	exists, err := server.objectExists(dn)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}
	if !exists {
		return ldap.LDAPResultNoSuchObject, fmt.Errorf("Not found: %s", dn)
	}

	values, err := server.getAttribute(dn, attr)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	for _, v := range values {
		if v == expected {
			return ldap.LDAPResultCompareTrue, nil
		}
	}
	return ldap.LDAPResultCompareFalse, nil
}

// Search request -------------------------

// handleSearch serves an LDAP Search operation. Matching entries are streamed
// to w by handleSearchInternal; afterwards a SearchResultDone carrying the
// final code (and error text, if any) is written. Non-success outcomes are
// logged.
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetSearchRequest()

	code, err := server.handleSearchInternal(state, w, &r)

	res := ldap.NewResponse(code)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
	}
	if code != ldap.LDAPResultSuccess {
		server.logger.Printf("Failed to do search %#v (%s)", r, err)
	}
	w.Write(message.SearchResultDone(res))
}

// handleSearchInternal lists the entries stored under the request's base
// object, keeps those matching the request filter that the user may read, and
// writes each one to w as a SearchResultEntry. Each entry carries only the
// attributes that were requested and that the user may read.
//
// Attribute selection: an empty attribute list or the special selector "*"
// returns all attributes (RFC 4511 §4.5.1.8). Otherwise requested names are
// matched case-insensitively.
//
// NOTE(review): the request's scope and size limit are not consulted; every
// entry returned by kv.List under the base path is considered.
func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) {
	if DEBUG {
		server.logger.Printf("-- SEARCH REQUEST: --")
		server.logger.Printf("Request BaseDn=%s", r.BaseObject())
		server.logger.Printf("Request Filter=%s", r.Filter())
		server.logger.Printf("Request FilterString=%s", r.FilterString())
		server.logger.Printf("Request Attributes=%s", r.Attributes())
		server.logger.Printf("Request TimeLimit=%d", r.TimeLimit().Int())
	}

	if !server.config.Acl.Check(&state.login, "read", string(r.BaseObject()), []string{}) {
		return ldap.LDAPResultInsufficientAccessRights, fmt.Errorf("Please specify a base object on which you have read rights")
	}

	baseObject, err := server.checkSuffix(string(r.BaseObject()), true)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}
	basePath, err := dnToConsul(baseObject)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	data, _, err := server.kv.List(basePath+"/", nil)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	entries, err := parseConsulResult(data)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	if DEBUG {
		server.logger.Printf("in %s: %#v", basePath+"/", data)
		server.logger.Printf("%#v", entries)
	}

	// Resolve the attribute selection once, outside the per-entry loops.
	requested := r.Attributes()
	allAttributes := len(requested) == 0
	for _, want := range requested {
		if string(want) == "*" {
			allAttributes = true
			break
		}
	}

	for dn, entry := range entries {
		// Filter out if we don't match requested filter.
		// NOTE(review): the filter sees every attribute of the entry, including
		// ones the per-attribute ACL check below would hide from this user.
		matched, err := applyFilter(entry, r.Filter())
		if err != nil {
			return ldap.LDAPResultUnwillingToPerform, err
		}
		if !matched {
			continue
		}

		// Filter out if user is not allowed to read this entry.
		if !server.config.Acl.Check(&state.login, "read", dn, []string{}) {
			continue
		}

		e := ldap.NewSearchResultEntry(dn)
		for attr, values := range entry {
			// If attribute is not in request, exclude it from returned entry.
			if !allAttributes {
				selected := false
				for _, want := range requested {
					if strings.EqualFold(string(want), attr) {
						selected = true
						break
					}
				}
				if !selected {
					continue
				}
			}

			// If we are not allowed to read attribute, exclude it from returned entry.
			if !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
				continue
			}

			for _, v := range values {
				e.AddAttribute(message.AttributeDescription(attr), message.AttributeValue(v))
			}
		}
		w.Write(e)
	}

	return ldap.LDAPResultSuccess, nil
}

// applyFilter reports whether entry satisfies filter.
//
// Supported filter kinds are AND, OR, NOT, presence and equality; any other
// kind yields an "Unsupported filter" error, which also propagates out of
// enclosing AND/OR/NOT filters. Attribute names are compared
// case-insensitively and equality values exactly (case-sensitive). Every key
// of entry matching the attribute name case-insensitively is examined, so the
// result does not depend on map iteration order.
func applyFilter(entry Entry, filter message.Filter) (bool, error) {
	switch f := filter.(type) {
	case message.FilterAnd:
		// Short-circuits on the first non-matching condition; an empty AND is true.
		for _, cond := range f {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if !res {
				return false, nil
			}
		}
		return true, nil

	case message.FilterOr:
		// Short-circuits on the first matching condition; an empty OR is false.
		for _, cond := range f {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if res {
				return true, nil
			}
		}
		return false, nil

	case message.FilterNot:
		res, err := applyFilter(entry, f.Filter)
		if err != nil {
			return false, err
		}
		return !res, nil

	case message.FilterPresent:
		what := string(f)
		for desc, values := range entry {
			if strings.EqualFold(what, desc) && len(values) > 0 {
				return true, nil
			}
		}
		return false, nil

	case message.FilterEqualityMatch:
		desc := string(f.AttributeDesc())
		target := string(f.AssertionValue())
		for entryDesc, values := range entry {
			if !strings.EqualFold(entryDesc, desc) {
				continue
			}
			for _, v := range values {
				if v == target {
					return true, nil
				}
			}
		}
		return false, nil

	default:
		return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter)
	}
}