From b238a6eb0ee31da11a8a78a99bcb017035011e1f Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Sun, 19 Jan 2020 19:10:38 +0100 Subject: [PATCH] Implement add with group membership --- main.go | 135 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 110 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index aeeb421..92c0712 100644 --- a/main.go +++ b/main.go @@ -46,11 +46,27 @@ func consulToDN(pair *consul.KVPair) (string, string, []byte) { return dn, "", nil } +func parseValue(value []byte) ([]string, error) { + val := []string{} + err := json.Unmarshal(value, &val) + if err == nil { + return val, nil + } + + val2 := "" + err = json.Unmarshal(value, &val2) + if err == nil { + return []string{val2}, nil + } + + return nil, fmt.Errorf("Not a string or list of strings: %s", value) +} + func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) { aggregator := map[string]Entry{} for _, kv := range data { - log.Printf("%s %s", kv.Key, string(kv.Value)) + log.Printf("(parseConsulResult) %s %s", kv.Key, string(kv.Value)) dn, attr, val := consulToDN(kv) if attr == "" || val == nil { continue @@ -58,8 +74,7 @@ func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) { if _, exists := aggregator[dn]; !exists { aggregator[dn] = Entry{} } - var value []string - err := json.Unmarshal(val, &value) + value, err := parseValue(val) if err != nil { return nil, err } @@ -224,6 +239,29 @@ func (server *Server) addElements(dn string, attrs Entry) error { return nil } +func (server *Server) getAttribute(dn string, attr string) ([]string, error) { + pair, _, err := server.kv.Get(dnToConsul(dn) + "/attribute=" + attr, nil) + if err != nil { + return nil, err + } + + if pair == nil { + return nil, nil + } + + return parseValue(pair.Value) +} + +func (server *Server) objectExists(dn string) (bool, error) { + prefix := dnToConsul(dn) + "/" + + data, _, err := server.kv.List(prefix, nil) + if err != nil { + return false, err + } + return len(data) > 0, nil +} + func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { state := s.(*State) r := m.GetBindRequest() @@ -240,28 +278,23 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda func (server *Server) handleBindInternal(state *State, r *message.BindRequest) (int, error) { - pair, _, err := server.kv.Get(dnToConsul(string(r.Name()))+"/attribute=userpassword", nil) + passwd, err := server.getAttribute(string(r.Name()), "userpassword") if err != nil { return ldap.LDAPResultOperationsError, err } - if pair == nil { + if passwd == nil { return ldap.LDAPResultNoSuchObject, nil } - hash := "" - err = json.Unmarshal(pair.Value, &hash) - if err != nil { - return ldap.LDAPResultOperationsError, err - } - - valid := SSHAMatches(hash, []byte(r.AuthenticationSimple())) - if valid { - state.bindDn = string(r.Name()) - return ldap.LDAPResultSuccess, nil - } else { - return ldap.LDAPResultInvalidCredentials, nil + for _, hash := range passwd { + valid := SSHAMatches(hash, []byte(r.AuthenticationSimple())) + if valid { + state.bindDn = string(r.Name()) + return ldap.LDAPResultSuccess, nil + } } + return ldap.LDAPResultInvalidCredentials, nil } func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { @@ -287,9 +320,25 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, log.Printf("Request TimeLimit=%d", r.TimeLimit().Int()) // TODO check authorizations + baseObject := dnToConsul(string(r.BaseObject())) + minimalBaseObject := dnToConsul(server.config.Suffix) - basePath := dnToConsul(string(r.BaseObject())) + "/" - data, _, err := server.kv.List(basePath, nil) + if len(baseObject) <= len(minimalBaseObject) { + if baseObject != minimalBaseObject[:len(baseObject)] { + return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf( + "Only handling search results under DN=%s", + server.config.Suffix) + } + baseObject = minimalBaseObject + } else { + if baseObject[:len(minimalBaseObject)] != minimalBaseObject { + return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf( + "Only handling search results under DN=%s", + server.config.Suffix) + } + } + + data, _, err := server.kv.List(baseObject + "/", nil) if err != nil { return ldap.LDAPResultOperationsError, err } @@ -298,7 +347,7 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, if err != nil { return ldap.LDAPResultOperationsError, err } - log.Printf("in %s: %#v", basePath, data) + log.Printf("in %s: %#v", baseObject + "/", data) log.Printf("%#v", entries) for dn, entry := range entries { @@ -412,25 +461,42 @@ func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) { dn := string(r.Entry()) - prefix := dnToConsul(dn) + "/" - - data, _, err := server.kv.List(prefix, nil) + exists, err := server.objectExists(dn) if err != nil { return ldap.LDAPResultOperationsError, err } - if len(data) > 0 { + if exists { return ldap.LDAPResultEntryAlreadyExists, nil } // TODO check permissions + var members []string = nil entry := Entry{} for _, attribute := range r.Attributes() { key := string(attribute.Type_()) + if strings.EqualFold(key, "memberOf") { + return ldap.LDAPResultObjectClassViolation, fmt.Errorf( + "memberOf cannot be defined directly, membership must be specified in the group itself") + } vals_str := []string{} for _, val := range attribute.Vals() { vals_str = append(vals_str, string(val)) } + if strings.EqualFold(key, "member") { + members = vals_str + for _, member := range members { + exists, err = server.objectExists(member) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + if !exists { + return ldap.LDAPResultNoSuchObject, fmt.Errorf( + "Cannot add %s to members, it does not exist!", + member) + } + } + } entry[key] = vals_str } @@ -439,6 +505,25 @@ func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (in return ldap.LDAPResultOperationsError, err } + if members != nil { + for _, member := range members { + memberGroups, err := server.getAttribute(member, "memberOf") + if err != nil { + return ldap.LDAPResultOperationsError, err + } + if memberGroups == nil { + memberGroups = []string{} + } + + memberGroups = append(memberGroups, dn) + err = server.addElements(member, Entry{ + "memberOf": memberGroups, + }) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + } + } + return ldap.LDAPResultSuccess, nil } -