package main import ( "crypto/tls" "fmt" "io" "log" "net" "net/http" "os" "os/signal" "strings" "syscall" "time" "github.com/spf13/cobra" ) // dbg is a package-level debug logger, active only when --debug is passed. // Defaults to discarding output; serve command points it at stderr when --debug is set. var dbg = log.New(io.Discard, "[ward] ", 0) func main() { root := &cobra.Command{ Use: "ward", Short: "Kubernetes credential gateway", SilenceUsage: true, SilenceErrors: true, } root.AddCommand(newServeCmd()) root.AddCommand(newCredentialCmd()) if err := root.Execute(); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } func newServeCmd() *cobra.Command { fqdn := detectFQDN() domain := domainPart(fqdn) realm := strings.ToUpper(domain) var ( addr string tlsCert string tlsKey string k3sServer string serverCACert string clientCACert string clientCAKey string certDuration time.Duration clusterName string ldapDomain string ldapOn bool ldapURI string ldapBindDN string ldapBindPassword string kerberosOn bool keytabPath string spn string htpasswdPath string debugFlag bool ) cmd := &cobra.Command{ Use: "serve", Short: "Run the ward authentication server", RunE: func(cmd *cobra.Command, args []string) error { if debugFlag { dbg = log.New(os.Stderr, "[ward] ", log.Ltime) dbg.Print("debug logging enabled") } if k3sServer == "" { return fmt.Errorf("--k3s-server is required (e.g. 
https://k3s.example.com:6443)") } // ── LDAP ────────────────────────────────────────────────────────────────── var ldapAuth *LDAPAuth if ldapURI != "" || ldapOn { var la *LDAPAuth var err error if ldapURI != "" { la, err = NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword) } else { la, err = NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword) } if err != nil { return fmt.Errorf("LDAP: %w", err) } ldapAuth = la anon := ldapBindDN == "" log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, ldapDomain, anon) } if ldapAuth == nil && !kerberosOn && htpasswdPath == "" { return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd") } // ── Kerberos ────────────────────────────────────────────────────────────── var krbAuth *KerberosAuth if kerberosOn { ka, err := NewKerberosAuth(keytabPath, spn) if err != nil { return fmt.Errorf("Kerberos: %w", err) } krbAuth = ka log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm) } // ── htpasswd ────────────────────────────────────────────────────────────── var htpasswdAuth *HtpasswdAuth if htpasswdPath != "" { ha, err := NewHtpasswdAuth(htpasswdPath) if err != nil { return fmt.Errorf("htpasswd: %w", err) } htpasswdAuth = ha log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len()) sighup := make(chan os.Signal, 1) signal.Notify(sighup, syscall.SIGHUP) go func() { for range sighup { if err := htpasswdAuth.Reload(); err != nil { log.Printf("SIGHUP: htpasswd reload failed: %v", err) } else { log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len()) } } }() } // ── Handler ─────────────────────────────────────────────────────────────── h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{ ServerURL: k3sServer, ServerCACert: serverCACert, ClientCACert: clientCACert, ClientCAKey: clientCAKey, Duration: certDuration, ClusterName: clusterName, }) if err != nil { return 
fmt.Errorf("handler: %w", err) } // ── TLS ─────────────────────────────────────────────────────────────────── if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil { return fmt.Errorf("TLS: loading certificate: %w", err) } tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS12, GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey) if err != nil { return nil, fmt.Errorf("reloading TLS cert: %w", err) } return &cert, nil }, } ln, err := tls.Listen("tcp", addr, tlsConfig) if err != nil { return fmt.Errorf("listen %s: %w", addr, err) } mux := http.NewServeMux() mux.HandleFunc("/kubeconfig", h.ServeHTTP) mux.HandleFunc("/credential", h.ServeCredential) mux.HandleFunc("/bootstrap", h.ServeBootstrap) mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { http.NotFound(w, r) return } w.Header().Set("Content-Type", "text/plain; charset=utf-8") fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+ " GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+ " GET /credential ExecCredential JSON for kubectl exec plugin\n"+ " GET /kubeconfig kubeconfig with embedded client certificate\n") }) srv := &http.Server{ Handler: mux, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration) return srv.Serve(ln) }, } cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address") cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)") cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file") cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. 
https://k3s.example.com:6443)") cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)") cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)") cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key") cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates") cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs") cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation") cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)") cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. 
ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)") cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)") cmd.Flags().StringVar(&ldapBindPassword, "ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)") cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)") cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path") cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm)) cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)") cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)") return cmd } func newCredentialCmd() *cobra.Command { var ( server string username string noKerberos bool noCache bool debugFlag bool ) cmd := &cobra.Command{ Use: "credential", Short: "kubectl exec credential plugin — fetch an ExecCredential from ward", Long: `Acts as a kubectl exec credential plugin. Fetches an ExecCredential JSON from the ward server and prints it to stdout for kubectl to consume. Authentication priority: 1. Kerberos SPNEGO using the active credential cache (from kinit) 2. Basic auth — prompts for password, or reads $WARD_PASSWORD Credentials are cached in ~/.cache/ward/ and reused until 5 minutes before expiry, so kubectl invocations are fast after the first call. 
Debug output goes to stderr (kubectl surfaces this to the terminal): WARD_DEBUG=1 kubectl get nodes`, RunE: func(cmd *cobra.Command, args []string) error { if server == "" { return fmt.Errorf("--server is required") } // prepend https if no scheme is given, for user convenience if !strings.Contains(server, "://") { server = "https://" + server } // append port 8443 if no port is given parts := strings.Split(server, ":") if len(parts) == 2 { server += ":8443" } logf := func(format string, a ...any) { if debugFlag { fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...) } } if !noCache { if ec, ok := credReadCache(server, logf); ok { return credPrint(ec) } } ec, err := credFetch(server, username, noKerberos, logf) if err != nil { return err } if !noCache { credWriteCache(server, ec, logf) } return credPrint(ec) }, } cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)") cmd.Flags().StringVar(&username, "username", "", "username for Basic auth fallback (default: $USER)") cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; always use Basic auth") cmd.Flags().BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential") cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)") return cmd } // detectFQDN returns the fully-qualified domain name of the local host, // falling back to the short hostname if DNS resolution fails. func detectFQDN() string { hostname, err := os.Hostname() if err != nil { return "localhost" } if strings.Contains(hostname, ".") { return hostname } addrs, err := net.LookupHost(hostname) if err != nil || len(addrs) == 0 { return hostname } names, err := net.LookupAddr(addrs[0]) if err != nil || len(names) == 0 { return hostname } return strings.TrimSuffix(names[0], ".") } // domainPart strips the first label from a FQDN. 
// domainPart returns everything after the first dot of a host name, i.e. the
// FQDN with its leftmost label removed. A name with no dot is returned
// unchanged.
//
//	"host.example.com" → "example.com"
func domainPart(fqdn string) string {
	_, rest, found := strings.Cut(fqdn, ".")
	if !found {
		return fqdn
	}
	return rest
}

// firstLabel returns the leftmost dot-separated label of a domain name. A name
// with no dot is returned unchanged.
//
//	"example.com" → "example"
func firstLabel(domain string) string {
	// strings.Cut yields the whole input as "before" when the separator is absent.
	label, _, _ := strings.Cut(domain, ".")
	return label
}