From 4c5b3d929d327d6703d0aeafe8149c25a845a9f8 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Sun, 19 Jan 2020 17:55:25 +0100 Subject: [PATCH] Implement Search with basic filter support --- main.go | 212 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 206 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 1fc5d3b..745d3fd 100644 --- a/main.go +++ b/main.go @@ -27,6 +27,62 @@ func dnToConsul(dn string) string { return strings.Join(rdns, "/") } +func consulToDN(pair *consul.KVPair) (string, string, []byte) { + path := strings.Split(pair.Key, "/") + dn := "" + for _, cpath := range path { + if cpath == "" { + continue + } + kv := strings.Split(cpath, "=") + if len(kv) == 2 && kv[0] == "attribute" { + return dn, kv[1], pair.Value + } + if dn != "" { + dn = "," + dn + } + dn = cpath + dn + } + return dn, "", nil +} + +func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) { + aggregator := map[string]Entry{} + + for _, kv := range data { + log.Printf("%s %s", kv.Key, string(kv.Value)) + dn, attr, val := consulToDN(kv) + if attr == "" || val == nil { + continue + } + if _, exists := aggregator[dn]; !exists { + aggregator[dn] = Entry{} + } + var value interface{} + err := json.Unmarshal(val, &value) + if err != nil { + return nil, err + } + if vlist, ok := value.([]interface{}); ok { + vlist2 := []string{} + for _, v := range vlist { + if vstr, ok := v.(string); ok { + vlist2 = append(vlist2, vstr) + } else { + return nil, fmt.Errorf("Not a string: %#v", v) + } + } + aggregator[dn][attr] = vlist2 + } else if vstr, ok := value.(string); ok { + aggregator[dn][attr] = vstr + } else { + return nil, fmt.Errorf("Not a string or a list of strings: %#v", value) + } + } + + return aggregator, nil +} + type DNComponent struct { Type string Value string @@ -63,7 +119,7 @@ type State struct { bindDn string } -type Attributes map[string]interface{} +type Entry map[string]interface{} func main() { //ldap logger @@ -95,6 +151,7 @@ func main() { routes := ldap.NewRouteMux() routes.Bind(gobottin.handleBind) + routes.Search(gobottin.handleSearch) ldapserver.Handle(routes) // listen on 10389 @@ -120,7 +177,7 @@ func (server *Server) init() error { return nil } - base_attributes := Attributes{ + base_attributes := Entry{ "objectClass": []string{"top", "dcObject", "organization"}, "structuralObjectClass": "Organization", } @@ -141,7 +198,7 @@ func (server *Server) init() error { admin_pass_hash := SSHAEncode([]byte(admin_pass_str)) admin_dn := "cn=admin," + server.config.Suffix - admin_attributes := Attributes{ + admin_attributes := Entry{ "objectClass": []string{"simpleSecurityObject", "organizationalRole"}, "description": "LDAP administrator", "cn": "admin", @@ -164,7 +221,7 @@ func (server *Server) init() error { return nil } -func (server *Server) addElements(dn string, attrs Attributes) error { +func (server *Server) addElements(dn string, attrs Entry) error { prefix := dnToConsul(dn) for k, v := range attrs { json, err := json.Marshal(v) @@ -184,7 +241,7 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda state := s.(*State) r := m.GetBindRequest() - result_code, err := server.handleBindInternal(state, w, r) + result_code, err := server.handleBindInternal(state, w, &r) res := ldap.NewBindResponse(result_code) if err != nil { @@ -194,7 +251,7 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda w.Write(res) } -func (server *Server) handleBindInternal(state *State, w ldap.ResponseWriter, r message.BindRequest) (int, error) { +func (server *Server) handleBindInternal(state *State, w ldap.ResponseWriter, r *message.BindRequest) (int, error) { pair, _, err := server.kv.Get(dnToConsul(string(r.Name()))+"/attribute=userpassword", nil) if err != nil { @@ -219,3 +276,146 @@ func (server *Server) handleBindInternal(state *State, w ldap.ResponseWriter, r return ldap.LDAPResultInvalidCredentials, nil } } + +func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { + state := s.(*State) + r := m.GetSearchRequest() + + code, err := server.handleSearchInternal(state, w, &r) + + res := ldap.NewResponse(code) + if err != nil { + res.SetDiagnosticMessage(err.Error()) + } + w.Write(message.SearchResultDone(res)) +} + +func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, r *message.SearchRequest) (int, error) { + + log.Printf("-- SEARCH REQUEST: --") + log.Printf("Request BaseDn=%s", r.BaseObject()) + log.Printf("Request Filter=%s", r.Filter()) + log.Printf("Request FilterString=%s", r.FilterString()) + log.Printf("Request Attributes=%s", r.Attributes()) + log.Printf("Request TimeLimit=%d", r.TimeLimit().Int()) + + // TODO check authorizations + + basePath := dnToConsul(string(r.BaseObject())) + "/" + data, _, err := server.kv.List(basePath, nil) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + + entries, err := parseConsulResult(data) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + log.Printf("in %s: %#v", basePath, data) + log.Printf("%#v", entries) + + for dn, entry := range entries { + // TODO filter out if no permission to read this + matched, err := applyFilter(entry, r.Filter()) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + if !matched { + continue + } + + e := ldap.NewSearchResultEntry(dn) + for attr, val := range entry { + // If attribute is not in request, exclude it from returned entry + if len(r.Attributes()) > 0 { + found := false + for _, need := range r.Attributes() { + if string(need) == attr { + found = true + break + } + } + if !found { + continue + } + } + // Send result + if val_str, ok := val.(string); ok { + e.AddAttribute(message.AttributeDescription(attr), + message.AttributeValue(val_str)) + } else if val_strlist, ok := val.([]string); ok { + for _, v := range val_strlist { + e.AddAttribute(message.AttributeDescription(attr), + message.AttributeValue(v)) + } + } else { + panic(fmt.Sprintf("Invalid value: %#v", val)) + } + } + w.Write(e) + } + + return ldap.LDAPResultSuccess, nil +} + +func applyFilter(entry Entry, filter message.Filter) (bool, error) { + if fAnd, ok := filter.(message.FilterAnd); ok { + for _, cond := range fAnd { + res, err := applyFilter(entry, cond) + if err != nil { + return false, err + } + if !res { + return false, nil + } + } + return true, nil + } else if fOr, ok := filter.(message.FilterOr); ok { + for _, cond := range fOr { + res, err := applyFilter(entry, cond) + if err != nil { + return false, err + } + if res { + return true, nil + } + } + return false, nil + } else if fNot, ok := filter.(message.FilterNot); ok { + res, err := applyFilter(entry, fNot.Filter) + if err != nil { + return false, err + } + return !res, nil + } else if fPresent, ok := filter.(message.FilterPresent); ok { + what := string(fPresent) + log.Printf("Present filter: %s", what) + if _, ok := entry[what]; ok { + return true, nil + } + return false, nil + } else if fEquality, ok := filter.(message.FilterEqualityMatch); ok { + desc := string(fEquality.AttributeDesc()) + target := string(fEquality.AssertionValue()) + if value, ok := entry[desc]; ok { + if vstr, ok := value.(string); ok { + // If we have one value for the key, match exactly + return vstr == target, nil + } else if vlist, ok := value.([]string); ok { + // If we have several values for the key, one must match + for _, val := range vlist { + if val == target { + return true, nil + } + } + return false, nil + } else { + panic(fmt.Sprintf("Invalid value: %#v", value)) + } + } else { + return false, nil + } + } else { + return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter) + } +}