diff --git a/main.go b/main.go index d905d72..aedcfb7 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,11 @@ import ( message "github.com/vjeantet/goldap/message" ) -func dnToConsul(dn string) string { +func dnToConsul(dn string) (string, error) { + if strings.Contains(dn, "/") { + return "", fmt.Errorf("DN %s contains a /", dn) + } + rdns := strings.Split(dn, ",") // Reverse rdns @@ -24,7 +28,7 @@ func dnToConsul(dn string) string { rdns[i], rdns[j] = rdns[j], rdns[i] } - return strings.Join(rdns, "/") + return strings.Join(rdns, "/"), nil } func consulToDN(pair *consul.KVPair) (string, string, []byte) { @@ -173,7 +177,12 @@ func main() { } func (server *Server) init() error { - pair, _, err := server.kv.Get(dnToConsul(server.config.Suffix)+"/attribute=objectClass", nil) + path, err := dnToConsul(server.config.Suffix) + if err != nil { + return err + } + + pair, _, err := server.kv.Get(path+"/attribute=objectClass", nil) if err != nil { return err } @@ -227,7 +236,11 @@ func (server *Server) init() error { } func (server *Server) addElements(dn string, attrs Entry) error { - prefix := dnToConsul(dn) + prefix, err := dnToConsul(dn) + if err != nil { + return err + } + for k, v := range attrs { json, err := json.Marshal(v) if err != nil { @@ -243,7 +256,12 @@ func (server *Server) addElements(dn string, attrs Entry) error { } func (server *Server) getAttribute(dn string, attr string) ([]string, error) { - pair, _, err := server.kv.Get(dnToConsul(dn) + "/attribute=" + attr, nil) + path, err := dnToConsul(dn) + if err != nil { + return nil, err + } + + pair, _, err := server.kv.Get(path + "/attribute=" + attr, nil) if err != nil { return nil, err } @@ -256,9 +274,12 @@ func (server *Server) getAttribute(dn string, attr string) ([]string, error) { } func (server *Server) objectExists(dn string) (bool, error) { - prefix := dnToConsul(dn) + "/" + prefix, err := dnToConsul(dn) + if err != nil { + return false, err + } - data, _, err := server.kv.List(prefix, nil) + data, _, err := server.kv.List(prefix + "/", nil) if err != nil { return false, err } @@ -343,9 +364,12 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, if err != nil { return ldap.LDAPResultInvalidDNSyntax, err } - basePath := dnToConsul(baseObject) + "/" + basePath, err := dnToConsul(baseObject) + if err != nil { + return ldap.LDAPResultInvalidDNSyntax, err + } - data, _, err := server.kv.List(basePath, nil) + data, _, err := server.kv.List(basePath + "/", nil) if err != nil { return ldap.LDAPResultOperationsError, err } @@ -354,7 +378,7 @@ func (server *Server) handleSearchInternal(state *State, w ldap.ResponseWriter, if err != nil { return ldap.LDAPResultOperationsError, err } - log.Printf("in %s: %#v", basePath, data) + log.Printf("in %s: %#v", basePath + "/", data) log.Printf("%#v", entries) for dn, entry := range entries { @@ -631,8 +655,12 @@ func (server *Server) handleDeleteInternal(state *State, r *message.DelRequest) // TODO check user for permissions to write dn // Check that this LDAP entry exists and has no children - path := dnToConsul(dn) + "/" - items, _, err := server.kv.List(path, nil) + path, err := dnToConsul(dn) + if err != nil { + return ldap.LDAPResultInvalidDNSyntax, err + } + + items, _, err := server.kv.List(path + "/", nil) if err != nil { return ldap.LDAPResultOperationsError, err } @@ -655,7 +683,7 @@ func (server *Server) handleDeleteInternal(state *State, r *message.DelRequest) } // Delete the LDAP entry - _, err = server.kv.DeleteTree(path, nil) + _, err = server.kv.DeleteTree(path + "/", nil) if err != nil { return ldap.LDAPResultOperationsError, err } @@ -712,7 +740,12 @@ func (server *Server) handleModifyInternal(state *State, r *message.ModifyReques // TODO check user for permissions to write dn // Retrieve previous values (by the way, check object exists) - items, _, err := server.kv.List(dnToConsul(dn) + "/attribute=", nil) + path, err := dnToConsul(dn) + if err != nil { + return ldap.LDAPResultInvalidDNSyntax, err + } + + items, _, err := server.kv.List(path + "/attribute=", nil) if err != nil { return ldap.LDAPResultOperationsError, err }