diff --git a/main.go b/main.go index f717382..aeeb421 100644 --- a/main.go +++ b/main.go @@ -58,26 +58,12 @@ func parseConsulResult(data []*consul.KVPair) (map[string]Entry, error) { if _, exists := aggregator[dn]; !exists { aggregator[dn] = Entry{} } - var value interface{} + var value []string err := json.Unmarshal(val, &value) if err != nil { return nil, err } - if vlist, ok := value.([]interface{}); ok { - vlist2 := []string{} - for _, v := range vlist { - if vstr, ok := v.(string); ok { - vlist2 = append(vlist2, vstr) - } else { - return nil, fmt.Errorf("Not a string: %#v", v) - } - } - aggregator[dn][attr] = vlist2 - } else if vstr, ok := value.(string); ok { - aggregator[dn][attr] = vstr - } else { - return nil, fmt.Errorf("Not a string or a list of strings: %#v", value) - } + aggregator[dn][attr] = value } return aggregator, nil @@ -119,7 +105,7 @@ type State struct { bindDn string } -type Entry map[string]interface{} +type Entry map[string][]string func main() { //ldap logger @@ -152,6 +138,7 @@ func main() { routes := ldap.NewRouteMux() routes.Bind(gobottin.handleBind) routes.Search(gobottin.handleSearch) + routes.Add(gobottin.handleAdd) ldapserver.Handle(routes) // listen on 10389 @@ -179,13 +166,13 @@ func (server *Server) init() error { base_attributes := Entry{ "objectClass": []string{"top", "dcObject", "organization"}, - "structuralObjectClass": "Organization", + "structuralObjectClass": []string{"Organization"}, } suffix_dn, err := parseDN(server.config.Suffix) if err != nil { return err } - base_attributes[suffix_dn[0].Type] = suffix_dn[0].Value + base_attributes[suffix_dn[0].Type] = []string{suffix_dn[0].Value} err = server.addElements(server.config.Suffix, base_attributes) if err != nil { @@ -200,10 +187,10 @@ func (server *Server) init() error { admin_dn := "cn=admin," + server.config.Suffix admin_attributes := Entry{ "objectClass": []string{"simpleSecurityObject", "organizationalRole"}, - "description": "LDAP administrator", - "cn": "admin", - "userpassword": admin_pass_hash, - "structuralObjectClass": "organizationalRole", + "description": []string{"LDAP administrator"}, + "cn": []string{"admin"}, + "userpassword": []string{admin_pass_hash}, + "structuralObjectClass": []string{"organizationalRole"}, "permissions": []string{"read", "write"}, } @@ -241,7 +228,7 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda state := s.(*State) r := m.GetBindRequest() - result_code, err := server.handleBindInternal(state, w, &r) + result_code, err := server.handleBindInternal(state, &r) res := ldap.NewBindResponse(result_code) if err != nil { @@ -251,7 +238,7 @@ func (server *Server) handleBind(s ldap.UserState, w ldap.ResponseWriter, m *lda w.Write(res) } -func (server *Server) handleBindInternal(state *State, w ldap.ResponseWriter, r *message.BindRequest) (int, error) { +func (server *Server) handleBindInternal(state *State, r *message.BindRequest) (int, error) { pair, _, err := server.kv.Get(dnToConsul(string(r.Name()))+"/attribute=userpassword", nil) if err != nil { @@ -340,16 +327,9 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, } } // Send result - if val_str, ok := val.(string); ok { + for _, v := range val { e.AddAttribute(message.AttributeDescription(attr), - message.AttributeValue(val_str)) - } else if val_strlist, ok := val.([]string); ok { - for _, v := range val_strlist { - e.AddAttribute(message.AttributeDescription(attr), - message.AttributeValue(v)) - } - } else { - panic(fmt.Sprintf("Invalid value: %#v", val)) + message.AttributeValue(v)) } } w.Write(e) @@ -402,20 +382,12 @@ func applyFilter(entry Entry, filter message.Filter) (bool, error) { // Case insensitive attribute search for entry_desc, value := range entry { if strings.EqualFold(entry_desc, desc) { - if vstr, ok := value.(string); ok { - // If we have one value for the key, match exactly - return vstr == target, nil - } else if vlist, ok := value.([]string); ok { - // If we have several values for the key, one must match - for _, val := range vlist { - if val == target { - return true, nil - } + for _, val := range value { + if val == target { + return true, nil } - return false, nil - } else { - panic(fmt.Sprintf("Invalid value: %#v", value)) } + return false, nil } } return false, nil @@ -423,3 +395,50 @@ func applyFilter(entry Entry, filter message.Filter) (bool, error) { return false, fmt.Errorf("Unsupported filter: %#v %T", filter, filter) } } + +func (server *Server) handleAdd(s ldap.UserState, w ldap.ResponseWriter, m *ldap.Message) { + state := s.(*State) + r := m.GetAddRequest() + + code, err := server.handleAddInternal(state, &r) + + res := ldap.NewResponse(code) + if err != nil { + res.SetDiagnosticMessage(err.Error()) + } + w.Write(message.AddResponse(res)) +} + +func (server *Server) handleAddInternal(state *State, r *message.AddRequest) (int, error) { + dn := string(r.Entry()) + + prefix := dnToConsul(dn) + "/" + + data, _, err := server.kv.List(prefix, nil) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + if len(data) > 0 { + return ldap.LDAPResultEntryAlreadyExists, nil + } + + // TODO check permissions + + entry := Entry{} + for _, attribute := range r.Attributes() { + key := string(attribute.Type_()) + vals_str := []string{} + for _, val := range attribute.Vals() { + vals_str = append(vals_str, string(val)) + } + entry[key] = vals_str + } + + err = server.addElements(dn, entry) + if err != nil { + return ldap.LDAPResultOperationsError, err + } + + return ldap.LDAPResultSuccess, nil +} +