package main

import (
	"fmt"
	"strings"

	ldap "bottin/ldapserver"

	message "github.com/lor00x/goldap/message"
)

// Generic read utility functions ----------

// getAttribute returns the values stored for attribute attr of the object
// identified by dn.
//
// It lists every key under the object's "/attribute=" prefix and keeps those
// whose full key equals "<path>/attribute=<attr>" case-insensitively. Each
// matching value is decoded with parseValue and appended to the result. When
// nothing matches, an empty (non-nil) slice is returned.
func (server *Server) getAttribute(dn string, attr string) ([]string, error) {
	path, err := dnToConsul(dn)
	if err != nil {
		return nil, err
	}

	pairs, _, err := server.kv.List(path+"/attribute=", &server.readOpts)
	if err != nil {
		return nil, err
	}

	wantedKey := path + "/attribute=" + attr
	values := []string{}
	for _, pair := range pairs {
		if !strings.EqualFold(pair.Key, wantedKey) {
			continue
		}
		newVals, err := parseValue(pair.Value)
		if err != nil {
			return nil, err
		}
		values = append(values, newVals...)
	}

	return values, nil
}

// objectExists reports whether the object identified by dn has at least one
// key stored under its "/attribute=" prefix.
func (server *Server) objectExists(dn string) (bool, error) {
	prefix, err := dnToConsul(dn)
	if err != nil {
		return false, err
	}

	data, _, err := server.kv.List(prefix+"/attribute=", &server.readOpts)
	if err != nil {
		return false, err
	}
	return len(data) > 0, nil
}

// Compare request -------------------------

// handleCompare answers a compare request: it delegates to
// handleCompareInternal and writes a CompareResponse carrying the resulting
// code, plus the error text as diagnostic message when an error was returned.
func (server *Server) handleCompare(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetCompareRequest()

	code, err := server.handleCompareInternal(state, &r)

	res := ldap.NewResponse(code)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
	}
	w.Write(message.CompareResponse(res))
}

// handleCompareInternal evaluates a compare request and returns the LDAP
// result code together with an optional error used as diagnostic message.
//
// The result is LDAPResultCompareTrue when one of the stored values of the
// requested attribute matches the assertion value according to valueMatch,
// and LDAPResultCompareFalse otherwise. Invalid DNs, insufficient read
// rights, missing objects and storage errors are reported with their
// respective result codes.
func (server *Server) handleCompareInternal(state *State, r *message.CompareRequest) (int, error) {
	attr := string(r.Ava().AttributeDesc())
	expected := string(r.Ava().AssertionValue())

	dn, err := server.checkDN(string(r.Entry()), false)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	// Check permissions.
	// The action comes before the target DN, as in every other Acl.Check
	// call in this file.
	if !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
		return ldap.LDAPResultInsufficientAccessRights, nil
	}

	// Do query
	exists, err := server.objectExists(dn)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}
	if !exists {
		return ldap.LDAPResultNoSuchObject, fmt.Errorf("Not found: %s", dn)
	}

	values, err := server.getAttribute(dn, attr)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	for _, v := range values {
		if valueMatch(attr, v, expected) {
			return ldap.LDAPResultCompareTrue, nil
		}
	}

	return ldap.LDAPResultCompareFalse, nil
}

// Search request -------------------------

// handleSearch answers a search request: it delegates to
// handleSearchInternal (which writes the individual result entries), logs
// the request when the result code is not LDAPResultSuccess, and finally
// writes the SearchResultDone message.
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
	state := s.(*State)
	r := m.GetSearchRequest()

	code, err := server.handleSearchInternal(state, w, &r)

	res := ldap.NewResponse(code)
	if err != nil {
		res.SetDiagnosticMessage(err.Error())
	}
	if code != ldap.LDAPResultSuccess {
		server.logger.Printf("Failed to do search %#v (%s)", r, err)
	}
	w.Write(message.SearchResultDone(res))
}

// handleSearchInternal performs the search described by r and writes one
// search result entry to w for every object that is in scope, matches the
// request filter and is readable by the current login. It returns the LDAP
// result code and an optional error used as diagnostic message.
//
// Attributes are included in an entry only if they were requested (an empty
// attribute list or "*" selects all, a lone "1.1" selects none) and the login
// is allowed to read them.
func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) {
	baseObject, err := server.checkDN(string(r.BaseObject()), true)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	server.logger.Tracef("-- SEARCH REQUEST: --")
	server.logger.Tracef("Request BaseDn=%s", baseObject)
	server.logger.Tracef("Request Filter=%s", r.Filter())
	server.logger.Tracef("Request FilterString=%s", r.FilterString())
	server.logger.Tracef("Request Attributes=%s", r.Attributes())
	server.logger.Tracef("Request TimeLimit=%d", r.TimeLimit().Int())

	if !server.config.Acl.Check(&state.login, "read", baseObject, []string{}) {
		return ldap.LDAPResultInsufficientAccessRights, fmt.Errorf("Please specify a base object on which you have read rights")
	}

	// Number of comma-separated components of the base DN; used to
	// recognise direct children for single-level searches.
	baseObjectLevel := len(strings.Split(baseObject, ","))

	basePath, err := dnToConsul(baseObject)
	if err != nil {
		return ldap.LDAPResultInvalidDNSyntax, err
	}

	scope := r.Scope()
	requestedAttrs := r.Attributes()

	// A base-object search only needs the attributes of the base object
	// itself; other scopes list everything stored below the base path.
	if scope == message.SearchRequestScopeBaseObject {
		basePath += "/attribute="
	} else {
		basePath += "/"
	}

	data, _, err := server.kv.List(basePath, &server.readOpts)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	entries, err := parseConsulResult(data)
	if err != nil {
		return ldap.LDAPResultOperationsError, err
	}

	server.logger.Tracef("in %s: %#v", basePath, data)
	server.logger.Tracef("%#v", entries)

	for dn, entry := range entries {
		if scope == message.SearchRequestScopeBaseObject {
			if dn != baseObject {
				continue
			}
		} else if scope == message.SearchRequestSingleLevel {
			objectLevel := len(strings.Split(dn, ","))
			if objectLevel != baseObjectLevel+1 {
				continue
			}
		}

		// Filter out if we don't match requested filter
		matched, err := applyFilter(entry, r.Filter())
		if err != nil {
			return ldap.LDAPResultUnwillingToPerform, err
		}
		if !matched {
			continue
		}

		// Filter out if user is not allowed to read this
		if !server.config.Acl.Check(&state.login, "read", dn, []string{}) {
			continue
		}

		e := ldap.NewSearchResultEntry(dn)
		for attr, val := range entry {
			// If attribute is not in request, exclude it from returned entry
			if len(requestedAttrs) > 0 {
				found := false
				for _, requested := range requestedAttrs {
					// A lone "1.1" selects no attributes at all.
					if string(requested) == "1.1" && len(requestedAttrs) == 1 {
						break
					}
					if string(requested) == "*" || strings.EqualFold(string(requested), attr) {
						found = true
						break
					}
				}
				if !found {
					continue
				}
			}

			// If we are not allowed to read attribute, exclude it from returned entry
			if !server.config.Acl.Check(&state.login, "read", dn, []string{attr}) {
				continue
			}

			// Send result
			for _, v := range val {
				e.AddAttribute(message.AttributeDescription(attr),
					message.AttributeValue(v))
			}
		}
		w.Write(e)
	}

	return ldap.LDAPResultSuccess, nil
}

// applyFilter reports whether entry satisfies filter.
//
// Supported filters are AND, OR, NOT, presence and equality match; AND and OR
// short-circuit on the first deciding operand. Attribute names are compared
// case-insensitively. Any other filter type yields an error.
func applyFilter(entry Entry, filter message.Filter) (bool, error) {
	switch f := filter.(type) {
	case message.FilterAnd:
		for _, cond := range f {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if !res {
				return false, nil
			}
		}
		return true, nil

	case message.FilterOr:
		for _, cond := range f {
			res, err := applyFilter(entry, cond)
			if err != nil {
				return false, err
			}
			if res {
				return true, nil
			}
		}
		return false, nil

	case message.FilterNot:
		res, err := applyFilter(entry, f.Filter)
		if err != nil {
			return false, err
		}
		return !res, nil

	case message.FilterPresent:
		what := string(f)
		// Case insensitive search
		for desc, values := range entry {
			if strings.EqualFold(what, desc) {
				return len(values) > 0, nil
			}
		}
		return false, nil

	case message.FilterEqualityMatch:
		desc := string(f.AttributeDesc())
		target := string(f.AssertionValue())
		// Case insensitive attribute search
		for entryDesc, values := range entry {
			if !strings.EqualFold(entryDesc, desc) {
				continue
			}
			for _, val := range values {
				if valueMatch(entryDesc, val, target) {
					return true, nil
				}
			}
			// Only the first attribute whose name matches is considered.
			return false, nil
		}
		return false, nil

	default:
		return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter)
	}
}