Move server to a 'serve' command

This commit is contained in:
2026-03-17 12:31:51 +01:00
parent 2b2d59aa76
commit 0bc41725f8
4 changed files with 253 additions and 210 deletions

384
main.go
View File

@@ -2,7 +2,6 @@ package main
import (
"crypto/tls"
"flag"
"fmt"
"io"
"log"
@@ -13,183 +12,278 @@ import (
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
)
// dbg is a package-level debug logger, active only when --debug is passed.
// Defaults to discarding output; main() points it at stderr when --debug is set.
// Defaults to discarding output; serve command points it at stderr when --debug is set.
var dbg = log.New(io.Discard, "[ward] ", 0)
func main() {
// ── Subcommand dispatch ───────────────────────────────────────────────────
// "ward credential ..." acts as a kubectl exec credential plugin.
// Handle it before parsing server flags so flag.Parse() doesn't choke on
// credential-specific flags.
if len(os.Args) > 1 && os.Args[1] == "credential" {
os.Exit(runCredential(os.Args[2:]))
root := &cobra.Command{
Use: "ward",
Short: "Kubernetes credential gateway",
SilenceUsage: true,
SilenceErrors: true,
}
// ── Server flags ──────────────────────────────────────────────────────────
root.AddCommand(newServeCmd())
root.AddCommand(newCredentialCmd())
if err := root.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func newServeCmd() *cobra.Command {
fqdn := detectFQDN()
domain := domainPart(fqdn)
realm := strings.ToUpper(domain)
var (
addr = flag.String("addr", ":8443", "Listen address")
tlsCert = flag.String("tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
tlsKey = flag.String("tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
k3sServer = flag.String("k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
serverCACert = flag.String("server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
clientCACert = flag.String("client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
clientCAKey = flag.String("client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
certDuration = flag.Duration("cert-duration", 24*time.Hour, "Validity period of generated client certificates")
clusterName = flag.String("cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
addr string
tlsCert string
tlsKey string
k3sServer string
serverCACert string
clientCACert string
clientCAKey string
certDuration time.Duration
clusterName string
// LDAP (opt-in; auto-discovered via DNS SRV unless --ldap-uri is given)
ldapDomain = flag.String("domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
ldapOn = flag.Bool("ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
ldapURI = flag.String("ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
ldapBindDN = flag.String("ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
ldapBindPassword = flag.String("ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)")
ldapDomain string
ldapOn bool
ldapURI string
ldapBindDN string
ldapBindPassword string
// Kerberos SPNEGO (opt-in)
kerberosOn = flag.Bool("kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
keytabPath = flag.String("keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
spn = flag.String("spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN)\n\t\t\t(default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
kerberosOn bool
keytabPath string
spn string
// htpasswd (opt-in)
htpasswdPath = flag.String("htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
// Debug logging
debugFlag = flag.Bool("debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
htpasswdPath string
debugFlag bool
)
flag.Parse()
if *debugFlag {
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
dbg.Print("debug logging enabled")
}
cmd := &cobra.Command{
Use: "serve",
Short: "Run the ward authentication server",
RunE: func(cmd *cobra.Command, args []string) error {
if debugFlag {
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
dbg.Print("debug logging enabled")
}
if *k3sServer == "" {
log.Fatal("--k3s-server is required (e.g. https://k3s.example.com:6443)")
}
if k3sServer == "" {
return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
}
// ── LDAP ──────────────────────────────────────────────────────────────────
var ldapAuth *LDAPAuth
if *ldapURI != "" || *ldapOn {
var la *LDAPAuth
var err error
if *ldapURI != "" {
la, err = NewLDAPAuthFromURI(*ldapURI, *ldapDomain, *ldapBindDN, *ldapBindPassword)
} else {
la, err = NewLDAPAuth(*ldapDomain, *ldapBindDN, *ldapBindPassword)
}
if err != nil {
log.Fatalf("LDAP: %v", err)
}
ldapAuth = la
anon := *ldapBindDN == ""
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, *ldapDomain, anon)
}
if ldapAuth == nil && !*kerberosOn && *htpasswdPath == "" {
log.Fatal("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
}
// ── Kerberos ──────────────────────────────────────────────────────────────
var krbAuth *KerberosAuth
if *kerberosOn {
ka, err := NewKerberosAuth(*keytabPath, *spn)
if err != nil {
log.Fatalf("Kerberos: %v", err)
}
krbAuth = ka
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", *keytabPath, *spn, realm)
}
// ── htpasswd ──────────────────────────────────────────────────────────────
var htpasswdAuth *HtpasswdAuth
if *htpasswdPath != "" {
ha, err := NewHtpasswdAuth(*htpasswdPath)
if err != nil {
log.Fatalf("htpasswd: %v", err)
}
htpasswdAuth = ha
log.Printf("htpasswd: %s (%d entries)", *htpasswdPath, ha.Len())
// SIGHUP reloads the file — no restart needed to add/remove users.
sighup := make(chan os.Signal, 1)
signal.Notify(sighup, syscall.SIGHUP)
go func() {
for range sighup {
if err := htpasswdAuth.Reload(); err != nil {
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
// ── LDAP ──────────────────────────────────────────────────────────────────
var ldapAuth *LDAPAuth
if ldapURI != "" || ldapOn {
var la *LDAPAuth
var err error
if ldapURI != "" {
la, err = NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword)
} else {
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
la, err = NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword)
}
if err != nil {
return fmt.Errorf("LDAP: %w", err)
}
ldapAuth = la
anon := ldapBindDN == ""
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, ldapDomain, anon)
}
}()
}
// ── Handler ───────────────────────────────────────────────────────────────
h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{
ServerURL: *k3sServer,
ServerCACert: *serverCACert,
ClientCACert: *clientCACert,
ClientCAKey: *clientCAKey,
Duration: *certDuration,
ClusterName: *clusterName,
})
if err != nil {
log.Fatalf("handler: %v", err)
}
if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
}
// ── TLS ───────────────────────────────────────────────────────────────────
// Verify the cert is readable at startup; fail loudly rather than on first connection.
if _, err := tls.LoadX509KeyPair(*tlsCert, *tlsKey); err != nil {
log.Fatalf("TLS: loading certificate: %v", err)
}
// Reload from disk on every handshake — Let's Encrypt renewals are picked up
// automatically without restarting the service.
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(*tlsCert, *tlsKey)
// ── Kerberos ──────────────────────────────────────────────────────────────
var krbAuth *KerberosAuth
if kerberosOn {
ka, err := NewKerberosAuth(keytabPath, spn)
if err != nil {
return fmt.Errorf("Kerberos: %w", err)
}
krbAuth = ka
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm)
}
// ── htpasswd ──────────────────────────────────────────────────────────────
var htpasswdAuth *HtpasswdAuth
if htpasswdPath != "" {
ha, err := NewHtpasswdAuth(htpasswdPath)
if err != nil {
return fmt.Errorf("htpasswd: %w", err)
}
htpasswdAuth = ha
log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
sighup := make(chan os.Signal, 1)
signal.Notify(sighup, syscall.SIGHUP)
go func() {
for range sighup {
if err := htpasswdAuth.Reload(); err != nil {
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
} else {
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
}
}
}()
}
// ── Handler ───────────────────────────────────────────────────────────────
h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{
ServerURL: k3sServer,
ServerCACert: serverCACert,
ClientCACert: clientCACert,
ClientCAKey: clientCAKey,
Duration: certDuration,
ClusterName: clusterName,
})
if err != nil {
return nil, fmt.Errorf("reloading TLS cert: %w", err)
return fmt.Errorf("handler: %w", err)
}
return &cert, nil
// ── TLS ───────────────────────────────────────────────────────────────────
if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil {
return fmt.Errorf("TLS: loading certificate: %w", err)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
if err != nil {
return nil, fmt.Errorf("reloading TLS cert: %w", err)
}
return &cert, nil
},
}
ln, err := tls.Listen("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("listen %s: %w", addr, err)
}
mux := http.NewServeMux()
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
mux.HandleFunc("/credential", h.ServeCredential)
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
" GET /kubeconfig kubeconfig with embedded client certificate\n")
})
srv := &http.Server{
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
return srv.Serve(ln)
},
}
ln, err := tls.Listen("tcp", *addr, tlsConfig)
if err != nil {
log.Fatalf("listen %s: %v", *addr, err)
cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
cmd.Flags().StringVar(&ldapBindPassword, "ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)")
cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
return cmd
}
func newCredentialCmd() *cobra.Command {
var (
server string
username string
noKerberos bool
noCache bool
debugFlag bool
)
cmd := &cobra.Command{
Use: "credential",
Short: "kubectl exec credential plugin — fetch an ExecCredential from ward",
Long: `Acts as a kubectl exec credential plugin. Fetches an ExecCredential JSON
from the ward server and prints it to stdout for kubectl to consume.
Authentication priority:
1. Kerberos SPNEGO using the active credential cache (from kinit)
2. Basic auth — prompts for password, or reads $WARD_PASSWORD
Credentials are cached in ~/.cache/ward/ and reused until 5 minutes
before expiry, so kubectl invocations are fast after the first call.
Debug output goes to stderr (kubectl surfaces this to the terminal):
WARD_DEBUG=1 kubectl get nodes`,
RunE: func(cmd *cobra.Command, args []string) error {
if server == "" {
return fmt.Errorf("--server is required")
}
logf := func(format string, a ...interface{}) {
if debugFlag {
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
}
}
if !noCache {
if ec, ok := credReadCache(server, logf); ok {
return credPrint(ec)
}
}
ec, err := credFetch(server, username, noKerberos, logf)
if err != nil {
return err
}
if !noCache {
credWriteCache(server, ec, logf)
}
return credPrint(ec)
},
}
mux := http.NewServeMux()
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
mux.HandleFunc("/credential", h.ServeCredential)
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
" GET /kubeconfig kubeconfig with embedded client certificate\n")
})
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
cmd.Flags().StringVar(&username, "username", "", "username for Basic auth fallback (default: $USER)")
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; always use Basic auth")
cmd.Flags().BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
srv := &http.Server{
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
log.Printf("ward listening on %s (cert-duration=%s)", *addr, *certDuration)
log.Fatal(srv.Serve(ln))
return cmd
}
// detectFQDN returns the fully-qualified domain name of the local host,