Import
This commit is contained in:
276
ldap.go
Normal file
276
ldap.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
)
|
||||
|
||||
// LDAPAuth authenticates users against an LDAP directory discovered via DNS SRV.
|
||||
type LDAPAuth struct {
|
||||
domain string
|
||||
host string
|
||||
port int
|
||||
useTLS bool // true = LDAPS (TLS from the start); false = STARTTLS on port 389
|
||||
bindDN string // empty = anonymous bind
|
||||
bindPassword string
|
||||
}
|
||||
|
||||
// NewLDAPAuth discovers the LDAP server for domain via the standard _ldap._tcp
|
||||
// DNS SRV record. The connection mode is derived from the advertised port:
|
||||
// port 389 uses STARTTLS; anything else (typically 636) uses LDAPS.
|
||||
// _ldaps._tcp is not an IANA-registered SRV type and is not consulted.
|
||||
// bindDN and bindPassword are used for the search bind; both empty = anonymous.
|
||||
func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) {
|
||||
host, port, err := discoverLDAP(domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LDAPAuth{
|
||||
domain: domain,
|
||||
host: host,
|
||||
port: port,
|
||||
useTLS: port != 389,
|
||||
bindDN: bindDN,
|
||||
bindPassword: bindPassword,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewLDAPAuthFromURI creates an LDAPAuth from an explicit URI instead of DNS
|
||||
// SRV discovery. Supported schemes: ldap:// (STARTTLS) and ldaps:// (TLS).
|
||||
// Port defaults to 389 for ldap:// and 636 for ldaps:// if not specified.
|
||||
// domain is still required for base-DN derivation and UPN construction.
|
||||
// bindDN and bindPassword are used for the search bind; both empty = anonymous.
|
||||
func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string) (*LDAPAuth, error) {
|
||||
u, err := url.Parse(rawURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid LDAP URI %q: %w", rawURI, err)
|
||||
}
|
||||
var useTLS bool
|
||||
var defaultPort int
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "ldap":
|
||||
useTLS = false
|
||||
defaultPort = 389
|
||||
case "ldaps":
|
||||
useTLS = true
|
||||
defaultPort = 636
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported LDAP URI scheme %q (want ldap:// or ldaps://)", u.Scheme)
|
||||
}
|
||||
host := u.Hostname()
|
||||
port := defaultPort
|
||||
if ps := u.Port(); ps != "" {
|
||||
port, err = strconv.Atoi(ps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port in LDAP URI %q: %w", rawURI, err)
|
||||
}
|
||||
}
|
||||
return &LDAPAuth{
|
||||
domain: domain,
|
||||
host: host,
|
||||
port: port,
|
||||
useTLS: useTLS,
|
||||
bindDN: bindDN,
|
||||
bindPassword: bindPassword,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func discoverLDAP(domain string) (host string, port int, err error) {
|
||||
_, addrs, err := net.LookupSRV("ldap", "tcp", domain)
|
||||
if err != nil || len(addrs) == 0 {
|
||||
return "", 0, fmt.Errorf("no _ldap._tcp SRV records found for %s", domain)
|
||||
}
|
||||
return strings.TrimSuffix(addrs[0].Target, "."), int(addrs[0].Port), nil
|
||||
}
|
||||
|
||||
func (a *LDAPAuth) connect() (*ldap.Conn, error) {
|
||||
addr := fmt.Sprintf("%s:%d", a.host, a.port)
|
||||
tlsConfig := &tls.Config{ServerName: a.host}
|
||||
if a.useTLS {
|
||||
dbg.Printf("LDAP: dialing TLS %s", addr)
|
||||
return ldap.DialTLS("tcp", addr, tlsConfig)
|
||||
}
|
||||
dbg.Printf("LDAP: dialing %s + STARTTLS", addr)
|
||||
conn, err := ldap.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := conn.StartTLS(tlsConfig); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("STARTTLS: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// searchBind binds with the configured service account, or anonymously if no
|
||||
// bind DN is configured.
|
||||
func (a *LDAPAuth) searchBind(conn *ldap.Conn) error {
|
||||
if a.bindDN == "" {
|
||||
return conn.UnauthenticatedBind("")
|
||||
}
|
||||
return conn.Bind(a.bindDN, a.bindPassword)
|
||||
}
|
||||
|
||||
// findUserDN searches for the user entry and returns its full DN.
|
||||
func (a *LDAPAuth) findUserDN(conn *ldap.Conn, username string) (string, error) {
|
||||
search := ldap.NewSearchRequest(
|
||||
domainToBaseDN(a.domain),
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
|
||||
1, 10, false,
|
||||
fmt.Sprintf("(|(sAMAccountName=%s)(uid=%s))",
|
||||
ldap.EscapeFilter(username),
|
||||
ldap.EscapeFilter(username)),
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
result, err := conn.Search(search)
|
||||
if err != nil || len(result.Entries) == 0 {
|
||||
return "", fmt.Errorf("user not found")
|
||||
}
|
||||
return result.Entries[0].DN, nil
|
||||
}
|
||||
|
||||
// Authenticate verifies username/password against LDAP and returns the user's
|
||||
// group CNs (used as Kubernetes RBAC groups via the certificate's Organisation field).
|
||||
// Returns a generic error on bad credentials to avoid user-enumeration.
|
||||
//
|
||||
// Flow: search bind → find user DN → user bind (verify password) → search bind → group lookup.
|
||||
func (a *LDAPAuth) Authenticate(username, password string) (groups []string, err error) {
|
||||
if username == "" || password == "" {
|
||||
return nil, fmt.Errorf("username and password required")
|
||||
}
|
||||
|
||||
conn, err := a.connect()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LDAP connect: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Bind as service account (or anonymously) to locate the user's DN.
|
||||
if err := a.searchBind(conn); err != nil {
|
||||
return nil, fmt.Errorf("LDAP search bind failed: %w", err)
|
||||
}
|
||||
|
||||
dbg.Printf("LDAP: searching for user %q", username)
|
||||
userDN, err := a.findUserDN(conn, username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Verify the password by binding as the user.
|
||||
dbg.Printf("LDAP: binding as %s", userDN)
|
||||
if err := conn.Bind(userDN, password); err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
dbg.Printf("LDAP: bind OK for %s", userDN)
|
||||
|
||||
// Re-bind as service account for group lookup — the user may lack read access.
|
||||
if err := a.searchBind(conn); err != nil {
|
||||
dbg.Printf("LDAP: re-bind for group lookup failed: %v — skipping groups", err)
|
||||
return nil, nil // auth succeeded; group lookup is best-effort
|
||||
}
|
||||
|
||||
groups = a.lookupGroups(conn, username, userDN)
|
||||
dbg.Printf("LDAP: groups for %s: %v", username, groups)
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// lookupGroups searches for group memberships using two strategies so it works
|
||||
// with both Active Directory (memberOf attribute) and POSIX/OpenLDAP layouts
|
||||
// (groupOfNames / posixGroup with member or memberUid attributes).
|
||||
func (a *LDAPAuth) lookupGroups(conn *ldap.Conn, username, userDN string) []string {
|
||||
baseDN := domainToBaseDN(a.domain)
|
||||
var groups []string
|
||||
|
||||
// AD-style: memberOf attribute on the user's own entry.
|
||||
memberOfSearch := ldap.NewSearchRequest(
|
||||
userDN,
|
||||
ldap.ScopeBaseObject, ldap.NeverDerefAliases,
|
||||
1, 10, false,
|
||||
"(objectClass=*)",
|
||||
[]string{"memberOf"},
|
||||
nil,
|
||||
)
|
||||
if result, err := conn.Search(memberOfSearch); err == nil && len(result.Entries) > 0 {
|
||||
for _, groupDN := range result.Entries[0].GetAttributeValues("memberOf") {
|
||||
if cn := cnFromDN(groupDN); cn != "" {
|
||||
groups = append(groups, cn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// POSIX/OpenLDAP style: search for groups that list this user as a member.
|
||||
groupSearch := ldap.NewSearchRequest(
|
||||
baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
|
||||
0, 10, false,
|
||||
fmt.Sprintf("(|(member=%s)(memberUid=%s))",
|
||||
ldap.EscapeFilter(userDN),
|
||||
ldap.EscapeFilter(username)),
|
||||
[]string{"cn"},
|
||||
nil,
|
||||
)
|
||||
if groupResult, err := conn.Search(groupSearch); err == nil {
|
||||
for _, entry := range groupResult.Entries {
|
||||
if cn := entry.GetAttributeValue("cn"); cn != "" && !containsStr(groups, cn) {
|
||||
groups = append(groups, cn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// LookupGroups searches for the user's groups using the configured bind credentials
|
||||
// (or anonymously). Used after Kerberos authentication to populate Kubernetes RBAC
|
||||
// group memberships.
|
||||
func (a *LDAPAuth) LookupGroups(username string) []string {
|
||||
conn, err := a.connect()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := a.searchBind(conn); err != nil {
|
||||
return nil
|
||||
}
|
||||
userDN, err := a.findUserDN(conn, username)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return a.lookupGroups(conn, username, userDN)
|
||||
}
|
||||
|
||||
// domainToBaseDN converts "example.com" to "dc=example,dc=com".
|
||||
func domainToBaseDN(domain string) string {
|
||||
parts := strings.Split(domain, ".")
|
||||
dcs := make([]string, len(parts))
|
||||
for i, p := range parts {
|
||||
dcs[i] = "dc=" + p
|
||||
}
|
||||
return strings.Join(dcs, ",")
|
||||
}
|
||||
|
||||
// cnFromDN extracts the CN value from the first CN= component of an LDAP DN.
|
||||
func cnFromDN(dn string) string {
|
||||
for _, part := range strings.Split(dn, ",") {
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(part)), "cn=") {
|
||||
return strings.TrimSpace(part)[3:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func containsStr(ss []string, s string) bool {
|
||||
for _, v := range ss {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user