package main import ( "crypto/tls" "flag" "fmt" "io" "log" "net" "net/http" "os" "os/signal" "strings" "syscall" "time" ) // dbg is a package-level debug logger, active only when --debug is passed. // Defaults to discarding output; main() points it at stderr when --debug is set. var dbg = log.New(io.Discard, "[ward] ", 0) func main() { // ── Subcommand dispatch ─────────────────────────────────────────────────── // "ward credential ..." acts as a kubectl exec credential plugin. // Handle it before parsing server flags so flag.Parse() doesn't choke on // credential-specific flags. if len(os.Args) > 1 && os.Args[1] == "credential" { os.Exit(runCredential(os.Args[2:])) } // ── Server flags ────────────────────────────────────────────────────────── fqdn := detectFQDN() domain := domainPart(fqdn) realm := strings.ToUpper(domain) var ( addr = flag.String("addr", ":8443", "Listen address") tlsCert = flag.String("tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)") tlsKey = flag.String("tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file") k3sServer = flag.String("k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. 
https://k3s.example.com:6443)") serverCACert = flag.String("server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)") clientCACert = flag.String("client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)") clientCAKey = flag.String("client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key") certDuration = flag.Duration("cert-duration", 24*time.Hour, "Validity period of generated client certificates") clusterName = flag.String("cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs") // LDAP (opt-in; auto-discovered via DNS SRV unless --ldap-uri is given) ldapDomain = flag.String("domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation") ldapOn = flag.Bool("ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)") ldapURI = flag.String("ldap-uri", "", "LDAP server URI, e.g. 
ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)") ldapBindDN = flag.String("ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)") ldapBindPassword = flag.String("ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)") // Kerberos SPNEGO (opt-in) kerberosOn = flag.Bool("kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)") keytabPath = flag.String("keytab", "/etc/krb5.keytab", "Kerberos service keytab path") spn = flag.String("spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN)\n\t\t\t(default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm)) // htpasswd (opt-in) htpasswdPath = flag.String("htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)") // Debug logging debugFlag = flag.Bool("debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)") ) flag.Parse() if *debugFlag { dbg = log.New(os.Stderr, "[ward] ", log.Ltime) dbg.Print("debug logging enabled") } if *k3sServer == "" { log.Fatal("--k3s-server is required (e.g. 
https://k3s.example.com:6443)") } // ── LDAP ────────────────────────────────────────────────────────────────── var ldapAuth *LDAPAuth if *ldapURI != "" || *ldapOn { var la *LDAPAuth var err error if *ldapURI != "" { la, err = NewLDAPAuthFromURI(*ldapURI, *ldapDomain, *ldapBindDN, *ldapBindPassword) } else { la, err = NewLDAPAuth(*ldapDomain, *ldapBindDN, *ldapBindPassword) } if err != nil { log.Fatalf("LDAP: %v", err) } ldapAuth = la anon := *ldapBindDN == "" log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, *ldapDomain, anon) } if ldapAuth == nil && !*kerberosOn && *htpasswdPath == "" { log.Fatal("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd") } // ── Kerberos ────────────────────────────────────────────────────────────── var krbAuth *KerberosAuth if *kerberosOn { ka, err := NewKerberosAuth(*keytabPath, *spn) if err != nil { log.Fatalf("Kerberos: %v", err) } krbAuth = ka log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", *keytabPath, *spn, realm) } // ── htpasswd ────────────────────────────────────────────────────────────── var htpasswdAuth *HtpasswdAuth if *htpasswdPath != "" { ha, err := NewHtpasswdAuth(*htpasswdPath) if err != nil { log.Fatalf("htpasswd: %v", err) } htpasswdAuth = ha log.Printf("htpasswd: %s (%d entries)", *htpasswdPath, ha.Len()) // SIGHUP reloads the file — no restart needed to add/remove users. 
sighup := make(chan os.Signal, 1) signal.Notify(sighup, syscall.SIGHUP) go func() { for range sighup { if err := htpasswdAuth.Reload(); err != nil { log.Printf("SIGHUP: htpasswd reload failed: %v", err) } else { log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len()) } } }() } // ── Handler ─────────────────────────────────────────────────────────────── h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{ ServerURL: *k3sServer, ServerCACert: *serverCACert, ClientCACert: *clientCACert, ClientCAKey: *clientCAKey, Duration: *certDuration, ClusterName: *clusterName, }) if err != nil { log.Fatalf("handler: %v", err) } // ── TLS ─────────────────────────────────────────────────────────────────── // Verify the cert is readable at startup; fail loudly rather than on first connection. if _, err := tls.LoadX509KeyPair(*tlsCert, *tlsKey); err != nil { log.Fatalf("TLS: loading certificate: %v", err) } // Reload from disk on every handshake — Let's Encrypt renewals are picked up // automatically without restarting the service. 
tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(*tlsCert, *tlsKey) if err != nil { return nil, fmt.Errorf("reloading TLS cert: %w", err) } return &cert, nil }, } ln, err := tls.Listen("tcp", *addr, tlsConfig) if err != nil { log.Fatalf("listen %s: %v", *addr, err) } mux := http.NewServeMux() mux.HandleFunc("/kubeconfig", h.ServeHTTP) mux.HandleFunc("/credential", h.ServeCredential) mux.HandleFunc("/bootstrap", h.ServeBootstrap) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { http.NotFound(w, r) return } w.Header().Set("Content-Type", "text/plain; charset=utf-8") fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+ " GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+ " GET /credential ExecCredential JSON for kubectl exec plugin\n"+ " GET /kubeconfig kubeconfig with embedded client certificate\n") }) srv := &http.Server{ Handler: mux, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } log.Printf("ward listening on %s (cert-duration=%s)", *addr, *certDuration) log.Fatal(srv.Serve(ln)) } // detectFQDN returns the fully-qualified domain name of the local host, // falling back to the short hostname if DNS resolution fails. func detectFQDN() string { hostname, err := os.Hostname() if err != nil { return "localhost" } if strings.Contains(hostname, ".") { return hostname } addrs, err := net.LookupHost(hostname) if err != nil || len(addrs) == 0 { return hostname } names, err := net.LookupAddr(addrs[0]) if err != nil || len(names) == 0 { return hostname } return strings.TrimSuffix(names[0], ".") } // domainPart strips the first label from a FQDN. // "host.example.com" → "example.com" func domainPart(fqdn string) string { if idx := strings.IndexByte(fqdn, '.'); idx >= 0 { return fqdn[idx+1:] } return fqdn } // firstLabel returns the first dot-separated label of a domain name. 
// "example.com" → "example"; a name containing no dot is returned unchanged,
// and an empty name yields "".
func firstLabel(domain string) string {
	label, _, _ := strings.Cut(domain, ".")
	return label
}