Implement add with group membership

This commit is contained in:
Alex 2020-01-19 19:10:38 +01:00
parent 2bfb6b4ced
commit b238a6eb0e

127
main.go
View file

@ -46,11 +46,27 @@ func consulToDN(pair *consul.KVPair) (string, string, []byte) {
return dn, "", nil return dn, "", nil
} }
func parseValue(value []byte) ([]string, error) {
val := []string{}
err := json.Unmarshal(value, &val)
if err == nil {
return val, nil
}
val2 := ""
err = json.Unmarshal(value, &val2)
if err == nil {
return []string{val2}, nil
}
return nil, fmt.Errorf("Not a string or list of strings: %s", value)
}
func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) { func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) {
aggregator := map[string]Entry{} aggregator := map[string]Entry{}
for _, kv := range data { for _, kv := range data {
log.Printf("%s %s", kv.Key, string(kv.Value)) log.Printf("(parseConsulResult) %s %s", kv.Key, string(kv.Value))
dn, attr, val := consulToDN(kv) dn, attr, val := consulToDN(kv)
if attr == "" || val == nil { if attr == "" || val == nil {
continue continue
@ -58,8 +74,7 @@ func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) {
if _, exists := aggregator[dn]; !exists { if _, exists := aggregator[dn]; !exists {
aggregator[dn] = Entry{} aggregator[dn] = Entry{}
} }
var value []string value, err := parseValue(val)
err := json.Unmarshal(val, &value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -224,6 +239,29 @@ func (server *Server) addElements(dn string, attrs Entry) error {
return nil return nil
} }
func (server *Server) getAttribute(dn string, attr string) ([]string, error) {
pair, _, err := server.kv.Get(dnToConsul(dn) + "/attribute=" + attr, nil)
if err != nil {
return nil, err
}
if pair == nil {
return nil, nil
}
return parseValue(pair.Value)
}
func (server *Server) objectExists(dn string) (bool, error) {
prefix := dnToConsul(dn) + "/"
data, _, err := server.kv.List(prefix, nil)
if err != nil {
return false, err
}
return len(data) > 0, nil
}
func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
state := s.(*State) state := s.(*State)
r := m.GetBindRequest() r := m.GetBindRequest()
@ -240,29 +278,24 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda
func (server *Server) handleBindInternal(state *State, r *message.BindRequest) (int, error) { func (server *Server) handleBindInternal(state *State, r *message.BindRequest) (int, error) {
pair, _, err := server.kv.Get(dnToConsul(string(r.Name()))+"/attribute=userpassword", nil) passwd, err := server.getAttribute(string(r.Name()), "userpassword")
if err != nil { if err != nil {
return ldap.LDAPResultOperationsError, err return ldap.LDAPResultOperationsError, err
} }
if pair == nil { if passwd == nil {
return ldap.LDAPResultNoSuchObject, nil return ldap.LDAPResultNoSuchObject, nil
} }
hash := "" for _, hash := range passwd {
err = json.Unmarshal(pair.Value, &hash)
if err != nil {
return ldap.LDAPResultOperationsError, err
}
valid := SSHAMatches(hash, []byte(r.AuthenticationSimple())) valid := SSHAMatches(hash, []byte(r.AuthenticationSimple()))
if valid { if valid {
state.bindDn = string(r.Name()) state.bindDn = string(r.Name())
return ldap.LDAPResultSuccess, nil return ldap.LDAPResultSuccess, nil
} else {
return ldap.LDAPResultInvalidCredentials, nil
} }
} }
return ldap.LDAPResultInvalidCredentials, nil
}
func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { func (server *Server) handleSearch(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) {
state := s.(*State) state := s.(*State)
@ -287,9 +320,25 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter,
log.Printf("Request TimeLimit=%d", r.TimeLimit().Int()) log.Printf("Request TimeLimit=%d", r.TimeLimit().Int())
// TODO check authorizations // TODO check authorizations
baseObject := dnToConsul(string(r.BaseObject()))
minimalBaseObject := dnToConsul(server.config.Suffix)
basePath := dnToConsul(string(r.BaseObject())) + "/" if len(baseObject) <= len(minimalBaseObject) {
data, _, err := server.kv.List(basePath, nil) if baseObject != minimalBaseObject[:len(baseObject)] {
return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf(
"Only handling search results under DN=%s",
server.config.Suffix)
}
baseObject = minimalBaseObject
} else {
if baseObject[:len(minimalBaseObject)] != minimalBaseObject {
return ldap.LDAPResultInvalidDNSyntax, fmt.Errorf(
"Only handling search results under DN=%s",
server.config.Suffix)
}
}
data, _, err := server.kv.List(baseObject + "/", nil)
if err != nil { if err != nil {
return ldap.LDAPResultOperationsError, err return ldap.LDAPResultOperationsError, err
} }
@ -298,7 +347,7 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter,
if err != nil { if err != nil {
return ldap.LDAPResultOperationsError, err return ldap.LDAPResultOperationsError, err
} }
log.Printf("in %s: %#v", basePath, data) log.Printf("in %s: %#v", baseObject + "/", data)
log.Printf("%#v", entries) log.Printf("%#v", entries)
for dn, entry := range entries { for dn, entry := range entries {
@ -412,25 +461,42 @@ func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap
func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) { func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) {
dn := string(r.Entry()) dn := string(r.Entry())
prefix := dnToConsul(dn) + "/" exists, err := server.objectExists(dn)
data, _, err := server.kv.List(prefix, nil)
if err != nil { if err != nil {
return ldap.LDAPResultOperationsError, err return ldap.LDAPResultOperationsError, err
} }
if len(data) > 0 { if exists {
return ldap.LDAPResultEntryAlreadyExists, nil return ldap.LDAPResultEntryAlreadyExists, nil
} }
// TODO check permissions // TODO check permissions
var members []string = nil
entry := Entry{} entry := Entry{}
for _, attribute := range r.Attributes() { for _, attribute := range r.Attributes() {
key := string(attribute.Type_()) key := string(attribute.Type_())
if strings.EqualFold(key, "memberOf") {
return ldap.LDAPResultObjectClassViolation, fmt.Errorf(
"memberOf cannot be defined directly, membership must be specified in the group itself")
}
vals_str := []string{} vals_str := []string{}
for _, val := range attribute.Vals() { for _, val := range attribute.Vals() {
vals_str = append(vals_str, string(val)) vals_str = append(vals_str, string(val))
} }
if strings.EqualFold(key, "member") {
members = vals_str
for _, member := range members {
exists, err = server.objectExists(member)
if err != nil {
return ldap.LDAPResultOperationsError, err
}
if !exists {
return ldap.LDAPResultNoSuchObject, fmt.Errorf(
"Cannot add %s to members, it does not exist!",
member)
}
}
}
entry[key] = vals_str entry[key] = vals_str
} }
@ -439,6 +505,25 @@ func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (in
return ldap.LDAPResultOperationsError, err return ldap.LDAPResultOperationsError, err
} }
return ldap.LDAPResultSuccess, nil if members != nil {
for _, member := range members {
memberGroups, err := server.getAttribute(member, "memberOf")
if err != nil {
return ldap.LDAPResultOperationsError, err
}
if memberGroups == nil {
memberGroups = []string{}
} }
memberGroups = append(memberGroups, dn)
err = server.addElements(member, Entry{
"memberOf": memberGroups,
})
if err != nil {
return ldap.LDAPResultOperationsError, err
}
}
}
return ldap.LDAPResultSuccess, nil
}