package cmd

import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/spf13/cobra"

	"git.shee.sh/james/ward/pkg/auth"
	"git.shee.sh/james/ward/pkg/cert"
	"git.shee.sh/james/ward/pkg/handler"
)

// certReloadInterval bounds how often the TLS certificate and key files are
// re-examined for changes. Between checks every handshake is served from the
// cached, already-parsed key pair.
const certReloadInterval = 30 * time.Second

// newServeCmd builds the "serve" subcommand, which runs the ward HTTPS server.
//
// Flag defaults are derived from the local host's FQDN (see detectFQDN): the
// TLS paths point under /etc/letsencrypt/live/<fqdn>/, the Kerberos SPN is
// HTTP/<fqdn>, and the LDAP domain and cluster name come from the FQDN's
// domain part.
func newServeCmd() *cobra.Command {
	fqdn := detectFQDN()
	domain := domainPart(fqdn)
	realm := strings.ToUpper(domain) // only used in flag help text

	var (
		addr              string
		tlsCert           string
		tlsKey            string
		k3sServer         string
		serverCACert      string
		clientCACert      string
		clientCAKey       string
		certDuration      time.Duration
		clusterName       string
		ldapDomain        string
		ldapOn            bool
		ldapURI           string
		ldapBindDN        string
		loginCertDuration time.Duration
		kerberosOn        bool
		keytabPath        string
		spn               string
		htpasswdPath      string
		debugFlag         bool
	)

	cmd := &cobra.Command{
		Use:   "serve",
		Short: "Run the ward authentication server",
		RunE: func(cmd *cobra.Command, args []string) error {
			dbg := log.New(io.Discard, "[ward] ", 0)
			if debugFlag {
				dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
				dbg.Print("debug logging enabled")
			}

			if k3sServer == "" {
				return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
			}

			// ── LDAP ──────────────────────────────────────────────────────────
			// The bind password is read only from the environment so it never
			// shows up in process listings.
			ldapBindPassword := os.Getenv("WARD_LDAP_BIND_PASSWORD")
			var ldapAuth *auth.LDAPAuth
			if ldapURI != "" || ldapOn {
				var la *auth.LDAPAuth
				var err error
				if ldapURI != "" {
					la, err = auth.NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword, dbg)
				} else {
					la, err = auth.NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword, dbg)
				}
				if err != nil {
					return fmt.Errorf("LDAP: %w", err)
				}
				ldapAuth = la
				anon := ldapBindDN == ""
				log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.Host(), la.Port(), la.UseTLS(), ldapDomain, anon)
			}

			if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
				return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
			}

			// ── Kerberos ──────────────────────────────────────────────────────
			var krbAuth *auth.KerberosAuth
			if kerberosOn {
				ka, err := auth.NewKerberosAuth(keytabPath, spn)
				if err != nil {
					return fmt.Errorf("Kerberos: %w", err)
				}
				krbAuth = ka
				// Report the realm implied by --domain (as its help text
				// promises), not the one derived from the auto-detected FQDN.
				log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, strings.ToUpper(ldapDomain))
			}

			// ── htpasswd ──────────────────────────────────────────────────────
			var htpasswdAuth *auth.HtpasswdAuth
			if htpasswdPath != "" {
				ha, err := auth.NewHtpasswdAuth(htpasswdPath)
				if err != nil {
					return fmt.Errorf("htpasswd: %w", err)
				}
				htpasswdAuth = ha
				log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
				reloadHtpasswdOnSIGHUP(ha)
			}

			// ── Handler ───────────────────────────────────────────────────────
			h, err := handler.NewHandler(ldapAuth, krbAuth, htpasswdAuth, &cert.CertConfig{
				ServerURL:     k3sServer,
				ServerCACert:  serverCACert,
				ClientCACert:  clientCACert,
				ClientCAKey:   clientCAKey,
				Duration:      certDuration,
				LoginDuration: loginCertDuration,
				ClusterName:   clusterName,
			}, dbg)
			if err != nil {
				return fmt.Errorf("handler: %w", err)
			}

			// ── TLS ───────────────────────────────────────────────────────────
			// Load the key pair once up front so a bad path or mismatched pair
			// fails startup; afterwards renewed files are picked up without a
			// restart and without re-parsing the key on every handshake.
			certs, err := newCertReloader(tlsCert, tlsKey)
			if err != nil {
				return fmt.Errorf("TLS: loading certificate: %w", err)
			}
			tlsConfig := &tls.Config{
				MinVersion:     tls.VersionTLS12,
				GetCertificate: certs.GetCertificate,
			}

			ln, err := tls.Listen("tcp", addr, tlsConfig)
			if err != nil {
				return fmt.Errorf("listen %s: %w", addr, err)
			}

			mux := http.NewServeMux()
			mux.HandleFunc("/kubeconfig", h.ServeHTTP)
			mux.HandleFunc("/credential", h.ServeCredential)
			mux.HandleFunc("/bootstrap", h.ServeBootstrap)
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				// "/" is a catch-all pattern in ServeMux; only the exact root
				// serves the index text.
				if r.URL.Path != "/" {
					http.NotFound(w, r)
					return
				}
				w.Header().Set("Content-Type", "text/plain; charset=utf-8")
				fmt.Fprint(w, "ward — Kubernetes credential gateway\n\n"+
					"  GET /bootstrap   kubeconfig with exec plugin pre-wired (no auth required)\n"+
					"  GET /credential  ExecCredential JSON for kubectl exec plugin\n"+
					"  GET /kubeconfig  kubeconfig with embedded client certificate\n")
			})

			srv := &http.Server{
				Handler:      mux,
				ReadTimeout:  30 * time.Second,
				WriteTimeout: 30 * time.Second,
			}
			log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
			return srv.Serve(ln)
		},
	}

	cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
	cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
	cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
	cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
	cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
	cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
	cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
	cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
	cmd.Flags().DurationVar(&loginCertDuration, "login-cert-duration", 168*time.Hour, "Validity period of certificates issued by 'ward login' (default 7 days)")
	cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
	cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
	cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
	cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
	cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
	// LDAP bind password is read exclusively from $WARD_LDAP_BIND_PASSWORD to avoid
	// exposure in process listings.
	cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
	cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
	cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
	cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
	cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")

	return cmd
}

// reloadHtpasswdOnSIGHUP starts a goroutine that calls ha.Reload every time
// the process receives SIGHUP and logs the outcome. The goroutine lives for
// the rest of the process. Registering with signal.Notify also means SIGHUP
// no longer terminates the process.
func reloadHtpasswdOnSIGHUP(ha *auth.HtpasswdAuth) {
	sighup := make(chan os.Signal, 1)
	signal.Notify(sighup, syscall.SIGHUP)
	go func() {
		for range sighup {
			if err := ha.Reload(); err != nil {
				log.Printf("SIGHUP: htpasswd reload failed: %v", err)
				continue
			}
			log.Printf("SIGHUP: htpasswd reloaded (%d entries)", ha.Len())
		}
	}()
}

// certReloader supplies the server's TLS certificate and picks up renewed
// certificate/key files without a restart.
//
// The parsed key pair is cached. At most once per certReloadInterval, a
// handshake stats both files and re-parses them only if either modification
// time changed. If a reload fails (for example, the certificate was replaced
// but the key not yet), the previously loaded pair keeps being served and the
// error is logged. Previously, every handshake re-read and re-parsed both
// files, and any such failure aborted the handshake.
type certReloader struct {
	certPath, keyPath string

	mu        sync.Mutex       // guards the fields below; GetCertificate is called concurrently
	cert      *tls.Certificate // last successfully loaded key pair
	certMod   time.Time        // mtime of certPath when cert was loaded
	keyMod    time.Time        // mtime of keyPath when cert was loaded
	checkedAt time.Time        // when the files were last examined
}

// newCertReloader loads the key pair once. It returns an error if either file
// cannot be read or the pair cannot be parsed, so misconfiguration surfaces
// at startup.
func newCertReloader(certPath, keyPath string) (*certReloader, error) {
	r := &certReloader{certPath: certPath, keyPath: keyPath}
	if err := r.reloadIfChanged(); err != nil {
		return nil, err
	}
	r.checkedAt = time.Now()
	return r, nil
}

// GetCertificate matches the signature of tls.Config.GetCertificate. Once
// newCertReloader has succeeded it always returns a certificate and never
// returns an error.
func (r *certReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if time.Since(r.checkedAt) >= certReloadInterval {
		r.checkedAt = time.Now()
		if err := r.reloadIfChanged(); err != nil {
			log.Printf("TLS: reloading certificate: %v (still serving previously loaded certificate)", err)
		}
	}
	return r.cert, nil
}

// reloadIfChanged re-parses the key pair when nothing is loaded yet, or when
// either file's modification time differs from the one recorded at the last
// successful load.
//
// os.Stat follows symlinks, so repointing a symlinked path at a newly written
// file is detected. Certbot's live/ directory, the default location, consists
// of such symlinks. The caller must hold r.mu or otherwise have exclusive
// access to r.
func (r *certReloader) reloadIfChanged() error {
	certInfo, err := os.Stat(r.certPath)
	if err != nil {
		return err
	}
	keyInfo, err := os.Stat(r.keyPath)
	if err != nil {
		return err
	}
	if r.cert != nil && certInfo.ModTime().Equal(r.certMod) && keyInfo.ModTime().Equal(r.keyMod) {
		return nil
	}
	pair, err := tls.LoadX509KeyPair(r.certPath, r.keyPath)
	if err != nil {
		return err
	}
	r.cert, r.certMod, r.keyMod = &pair, certInfo.ModTime(), keyInfo.ModTime()
	return nil
}

// detectFQDN returns the fully-qualified domain name of the local host.
//
// A hostname that already contains a dot is returned as-is. Otherwise the
// hostname's addresses are resolved and reverse-looked-up. The first reverse
// name that contains a dot and whose first label is not "localhost" is
// returned, with any trailing dot removed. Skipping localhost aliases matters
// because a loopback address commonly reverse-resolves to "localhost", which
// would otherwise become the FQDN.
//
// If no suitable name is found, the short hostname is returned. If the
// hostname itself is unavailable, "localhost" is returned.
func detectFQDN() string {
	hostname, err := os.Hostname()
	if err != nil {
		return "localhost"
	}
	if strings.Contains(hostname, ".") {
		return hostname
	}
	addrs, err := net.LookupHost(hostname)
	if err != nil {
		return hostname
	}
	for _, a := range addrs {
		names, err := net.LookupAddr(a)
		if err != nil {
			continue
		}
		for _, name := range names {
			name = strings.TrimSuffix(name, ".")
			if strings.Contains(name, ".") && !strings.EqualFold(firstLabel(name), "localhost") {
				return name
			}
		}
	}
	return hostname
}

// domainPart strips the first label from a FQDN.
// "host.example.com" → "example.com". A name without a dot is returned
// unchanged.
func domainPart(fqdn string) string {
	if _, rest, found := strings.Cut(fqdn, "."); found {
		return rest
	}
	return fqdn
}

// firstLabel returns the first dot-separated label of a domain name.
// "example.com" → "example". A name without a dot is returned unchanged.
func firstLabel(domain string) string {
	label, _, _ := strings.Cut(domain, ".")
	return label
}