Files
ward/cmd/serve.go
James McDonald a0a7932f99
All checks were successful
Release / release (push) Successful in 1m38s
Refactor project layout
2026-04-01 13:16:06 +02:00

253 lines
9.4 KiB
Go

package cmd
import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"git.shee.sh/james/ward/pkg/auth"
	"git.shee.sh/james/ward/pkg/cert"
	"git.shee.sh/james/ward/pkg/handler"
	"github.com/spf13/cobra"
)
// newServeCmd builds the "serve" sub-command, which runs the ward HTTPS
// authentication server.
//
// Flag defaults for the TLS certificate paths, the Kerberos SPN, the LDAP
// domain and the cluster name are derived from the local host's FQDN when the
// command is constructed. When run, --k3s-server is mandatory and at least one
// authentication provider (--ldap/--ldap-uri, --kerberos or --htpasswd) must be
// enabled.
func newServeCmd() *cobra.Command {
	fqdn := detectFQDN()
	domain := domainPart(fqdn)
	realm := strings.ToUpper(domain)

	var (
		addr              string
		tlsCert           string
		tlsKey            string
		k3sServer         string
		serverCACert      string
		clientCACert      string
		clientCAKey       string
		certDuration      time.Duration
		clusterName       string
		ldapDomain        string
		ldapOn            bool
		ldapURI           string
		ldapBindDN        string
		loginCertDuration time.Duration
		kerberosOn        bool
		keytabPath        string
		spn               string
		htpasswdPath      string
		debugFlag         bool
	)

	cmd := &cobra.Command{
		Use:   "serve",
		Short: "Run the ward authentication server",
		RunE: func(cmd *cobra.Command, args []string) error {
			// Debug output is discarded unless --debug / $WARD_DEBUG is set.
			dbg := log.New(io.Discard, "[ward] ", 0)
			if debugFlag {
				dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
				dbg.Print("debug logging enabled")
			}
			if k3sServer == "" {
				return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
			}

			// ── LDAP ──────────────────────────────────────────────────────────
			// The bind password is read exclusively from the environment so it
			// never shows up in process listings.
			ldapAuth, err := buildServeLDAPAuth(ldapURI, ldapOn, ldapDomain, ldapBindDN, os.Getenv("WARD_LDAP_BIND_PASSWORD"), dbg)
			if err != nil {
				return err
			}
			if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
				return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
			}

			// ── Kerberos ──────────────────────────────────────────────────────
			var krbAuth *auth.KerberosAuth
			if kerberosOn {
				ka, err := auth.NewKerberosAuth(keytabPath, spn)
				if err != nil {
					return fmt.Errorf("Kerberos: %w", err)
				}
				krbAuth = ka
				// --domain is documented as the source of the Kerberos realm, so
				// report the realm derived from it rather than from the host FQDN.
				log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, strings.ToUpper(ldapDomain))
			}

			// ── htpasswd ──────────────────────────────────────────────────────
			var htpasswdAuth *auth.HtpasswdAuth
			if htpasswdPath != "" {
				ha, err := buildServeHtpasswdAuth(htpasswdPath)
				if err != nil {
					return err
				}
				htpasswdAuth = ha
			}

			// ── Handler ───────────────────────────────────────────────────────
			h, err := handler.NewHandler(ldapAuth, krbAuth, htpasswdAuth, &cert.CertConfig{
				ServerURL:     k3sServer,
				ServerCACert:  serverCACert,
				ClientCACert:  clientCACert,
				ClientCAKey:   clientCAKey,
				Duration:      certDuration,
				LoginDuration: loginCertDuration,
				ClusterName:   clusterName,
			}, dbg)
			if err != nil {
				return fmt.Errorf("handler: %w", err)
			}

			// ── TLS ───────────────────────────────────────────────────────────
			// The key pair is loaded once up front (failing fast on a bad path)
			// and then re-read only when the files change on disk, so renewed
			// certificates are picked up without a restart.
			certs, err := newServeCertCache(tlsCert, tlsKey)
			if err != nil {
				return fmt.Errorf("TLS: loading certificate: %w", err)
			}
			tlsConfig := &tls.Config{
				MinVersion:     tls.VersionTLS12,
				GetCertificate: certs.GetCertificate,
			}
			ln, err := tls.Listen("tcp", addr, tlsConfig)
			if err != nil {
				return fmt.Errorf("listen %s: %w", addr, err)
			}

			mux := http.NewServeMux()
			mux.HandleFunc("/kubeconfig", h.ServeHTTP)
			mux.HandleFunc("/credential", h.ServeCredential)
			mux.HandleFunc("/bootstrap", h.ServeBootstrap)
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				// "/" matches every unrouted path; only the exact root gets the index.
				if r.URL.Path != "/" {
					http.NotFound(w, r)
					return
				}
				w.Header().Set("Content-Type", "text/plain; charset=utf-8")
				io.WriteString(w, "ward — Kubernetes credential gateway\n\n"+
					" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
					" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
					" GET /kubeconfig kubeconfig with embedded client certificate\n")
			})

			srv := &http.Server{
				Handler:      mux,
				ReadTimeout:  30 * time.Second,
				WriteTimeout: 30 * time.Second,
			}
			log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
			return srv.Serve(ln)
		},
	}

	cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
	cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
	cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
	cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
	cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
	cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
	cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
	cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
	cmd.Flags().DurationVar(&loginCertDuration, "login-cert-duration", 168*time.Hour, "Validity period of certificates issued by 'ward login' (default 7 days)")
	cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
	cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
	cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
	cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
	cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
	// LDAP bind password is read exclusively from $WARD_LDAP_BIND_PASSWORD to avoid
	// exposure in process listings.
	cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
	cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
	cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
	cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
	cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
	return cmd
}

// buildServeLDAPAuth returns an LDAP authenticator when LDAP is enabled
// (enabled is true or uri is non-empty), and (nil, nil) when it is not.
// A non-empty uri is used directly; otherwise the server is located from
// domain by auth.NewLDAPAuth. An empty bindDN is reported as an anonymous
// bind in the startup log.
func buildServeLDAPAuth(uri string, enabled bool, domain, bindDN, bindPassword string, dbg *log.Logger) (*auth.LDAPAuth, error) {
	if uri == "" && !enabled {
		return nil, nil
	}
	var (
		la  *auth.LDAPAuth
		err error
	)
	if uri != "" {
		la, err = auth.NewLDAPAuthFromURI(uri, domain, bindDN, bindPassword, dbg)
	} else {
		la, err = auth.NewLDAPAuth(domain, bindDN, bindPassword, dbg)
	}
	if err != nil {
		return nil, fmt.Errorf("LDAP: %w", err)
	}
	log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.Host(), la.Port(), la.UseTLS(), domain, bindDN == "")
	return la, nil
}

// buildServeHtpasswdAuth loads the htpasswd file at path and registers a
// SIGHUP handler that reloads it in place. The handler goroutine is never
// stopped; it runs for as long as the process does. A failed reload is
// logged and the previously loaded entries remain in effect only as far as
// auth.HtpasswdAuth.Reload guarantees (NOTE(review): confirm Reload does not
// clear entries on error).
func buildServeHtpasswdAuth(path string) (*auth.HtpasswdAuth, error) {
	ha, err := auth.NewHtpasswdAuth(path)
	if err != nil {
		return nil, fmt.Errorf("htpasswd: %w", err)
	}
	log.Printf("htpasswd: %s (%d entries)", path, ha.Len())

	sighup := make(chan os.Signal, 1)
	signal.Notify(sighup, syscall.SIGHUP)
	go func() {
		for range sighup {
			if err := ha.Reload(); err != nil {
				log.Printf("SIGHUP: htpasswd reload failed: %v", err)
				continue
			}
			log.Printf("SIGHUP: htpasswd reloaded (%d entries)", ha.Len())
		}
	}()
	return ha, nil
}

// serveFileStamp is the file metadata used to detect that a file changed.
type serveFileStamp struct {
	modTime time.Time
	size    int64
}

// equal reports whether two stamps describe the same file contents.
func (s serveFileStamp) equal(o serveFileStamp) bool {
	return s.modTime.Equal(o.modTime) && s.size == o.size
}

// statServeFile returns the stamp of path. os.Stat follows symlinks, so a
// repointed symlink (as Let's Encrypt's live/ directory uses) yields the new
// target's stamp.
func statServeFile(path string) (serveFileStamp, error) {
	fi, err := os.Stat(path)
	if err != nil {
		return serveFileStamp{}, err
	}
	return serveFileStamp{modTime: fi.ModTime(), size: fi.Size()}, nil
}

// serveCertCache supplies the server's TLS certificate from a cert/key file
// pair, re-reading them only when either file's modification time or size
// changes. If a changed pair fails to load (for example while a renewal has
// replaced one file but not yet the other), the previously loaded certificate
// keeps being served and the load is retried on the next handshake.
type serveCertCache struct {
	certFile, keyFile string

	mu        sync.Mutex // guards the fields below
	cert      *tls.Certificate
	certStamp serveFileStamp
	keyStamp  serveFileStamp
}

// newServeCertCache loads the key pair once, returning the
// tls.LoadX509KeyPair error if it cannot be loaded.
func newServeCertCache(certFile, keyFile string) (*serveCertCache, error) {
	c := &serveCertCache{certFile: certFile, keyFile: keyFile}
	if err := c.load(); err != nil {
		return nil, err
	}
	return c, nil
}

// load reads the key pair and records the file stamps. The caller must hold
// c.mu, or own c exclusively (as the constructor does).
func (c *serveCertCache) load() error {
	// Stamps are taken before reading so that a change racing with the read
	// leaves a stale stamp and triggers another load instead of being missed.
	// Stat errors are ignored on purpose: LoadX509KeyPair reports the real
	// problem, and a zero stamp simply forces a re-check next time.
	certStamp, _ := statServeFile(c.certFile)
	keyStamp, _ := statServeFile(c.keyFile)
	pair, err := tls.LoadX509KeyPair(c.certFile, c.keyFile)
	if err != nil {
		return err
	}
	c.cert, c.certStamp, c.keyStamp = &pair, certStamp, keyStamp
	return nil
}

// GetCertificate implements tls.Config.GetCertificate. It returns the cached
// certificate unless the files have changed since it was loaded.
func (c *serveCertCache) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	certStamp, certErr := statServeFile(c.certFile)
	keyStamp, keyErr := statServeFile(c.keyFile)
	if certErr == nil && keyErr == nil && certStamp.equal(c.certStamp) && keyStamp.equal(c.keyStamp) {
		return c.cert, nil
	}
	if err := c.load(); err != nil {
		// c.cert is non-nil: the constructor only succeeds after a good load.
		log.Printf("TLS: reloading certificate failed, serving previous one: %v", err)
	}
	return c.cert, nil
}
// detectFQDN reports the local machine's fully-qualified domain name.
//
// A hostname that already contains a dot is trusted as-is. Otherwise the
// hostname is resolved forward, the first address found is resolved back, and
// the first reverse name (minus its trailing root dot) is returned. Any failed
// or empty lookup falls back to the bare hostname; if the hostname itself is
// unavailable, "localhost" is returned.
func detectFQDN() string {
	host, err := os.Hostname()
	switch {
	case err != nil:
		return "localhost"
	case strings.Contains(host, "."):
		return host
	}

	ips, lookupErr := net.LookupHost(host)
	if lookupErr != nil || len(ips) == 0 {
		return host
	}
	reverse, reverseErr := net.LookupAddr(ips[0])
	if reverseErr != nil || len(reverse) == 0 {
		return host
	}
	return strings.TrimSuffix(reverse[0], ".")
}
// domainPart drops the leftmost label of fqdn, returning everything after the
// first dot: "host.example.com" → "example.com". A name with no dot at all is
// returned unchanged.
func domainPart(fqdn string) string {
	_, rest, found := strings.Cut(fqdn, ".")
	if !found {
		return fqdn
	}
	return rest
}
// firstLabel returns the part of domain before its first dot:
// "example.com" → "example". A name without a dot is returned whole.
func firstLabel(domain string) string {
	label, _, _ := strings.Cut(domain, ".")
	return label
}