Files
ward/main.go

327 lines
11 KiB
Go

package main
import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/spf13/cobra"
)
// dbg is the package-wide debug logger. Out of the box it writes to
// io.Discard; the serve command replaces it with a stderr logger when
// --debug (or $WARD_DEBUG) is set.
var (
	dbg = log.New(io.Discard, "[ward] ", 0)
)
// main is the ward entry point. It assembles the root command with the
// "serve" and "credential" subcommands and executes it. Cobra's own error
// and usage printing is silenced, so any error is reported here on stderr,
// followed by exit status 1.
func main() {
	rootCmd := &cobra.Command{
		Use:           "ward",
		Short:         "Kubernetes credential gateway",
		SilenceUsage:  true,
		SilenceErrors: true,
	}
	rootCmd.AddCommand(newServeCmd(), newCredentialCmd())

	err := rootCmd.Execute()
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}
// newServeCmd builds the "serve" subcommand, which runs the ward HTTPS
// authentication server. Defaults for the TLS certificate paths, SPN,
// --domain and --cluster-name are derived from the host FQDN returned by
// detectFQDN.
//
// RunE requires --k3s-server and at least one authentication provider
// (--ldap/--ldap-uri, --kerberos or --htpasswd). It then serves /kubeconfig,
// /credential, /bootstrap and a plain-text index at / over TLS on --addr.
func newServeCmd() *cobra.Command {
	fqdn := detectFQDN()
	domain := domainPart(fqdn)
	realm := strings.ToUpper(domain)
	var (
		addr             string
		tlsCert          string
		tlsKey           string
		k3sServer        string
		serverCACert     string
		clientCACert     string
		clientCAKey      string
		certDuration     time.Duration
		clusterName      string
		ldapDomain       string
		ldapOn           bool
		ldapURI          string
		ldapBindDN       string
		ldapBindPassword string
		kerberosOn       bool
		keytabPath       string
		spn              string
		htpasswdPath     string
		debugFlag        bool
	)
	cmd := &cobra.Command{
		Use:   "serve",
		Short: "Run the ward authentication server",
		RunE: func(cmd *cobra.Command, args []string) error {
			if debugFlag {
				dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
				dbg.Print("debug logging enabled")
			}
			if k3sServer == "" {
				return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
			}

			// ── LDAP ──────────────────────────────────────────────────────────────────
			var ldapAuth *LDAPAuth
			if ldapURI != "" || ldapOn {
				var la *LDAPAuth
				var err error
				// An explicit URI overrides DNS SRV discovery.
				if ldapURI != "" {
					la, err = NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword)
				} else {
					la, err = NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword)
				}
				if err != nil {
					return fmt.Errorf("LDAP: %w", err)
				}
				ldapAuth = la
				anon := ldapBindDN == ""
				log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, ldapDomain, anon)
			}
			if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
				return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
			}

			// ── Kerberos ──────────────────────────────────────────────────────────────
			var krbAuth *KerberosAuth
			if kerberosOn {
				ka, err := NewKerberosAuth(keytabPath, spn)
				if err != nil {
					return fmt.Errorf("Kerberos: %w", err)
				}
				krbAuth = ka
				// Report the realm implied by --domain (as the flag's help
				// text promises), not the one from the auto-detected host
				// domain, which may differ when --domain is overridden.
				log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, strings.ToUpper(ldapDomain))
			}

			// ── htpasswd ──────────────────────────────────────────────────────────────
			var htpasswdAuth *HtpasswdAuth
			if htpasswdPath != "" {
				ha, err := NewHtpasswdAuth(htpasswdPath)
				if err != nil {
					return fmt.Errorf("htpasswd: %w", err)
				}
				htpasswdAuth = ha
				log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
				// Re-read the password file on SIGHUP without a restart. The
				// channel is never stopped, so this goroutine lives for the
				// rest of the process.
				sighup := make(chan os.Signal, 1)
				signal.Notify(sighup, syscall.SIGHUP)
				go func() {
					for range sighup {
						if err := htpasswdAuth.Reload(); err != nil {
							log.Printf("SIGHUP: htpasswd reload failed: %v", err)
						} else {
							log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
						}
					}
				}()
			}

			// ── Handler ───────────────────────────────────────────────────────────────
			h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{
				ServerURL:    k3sServer,
				ServerCACert: serverCACert,
				ClientCACert: clientCACert,
				ClientCAKey:  clientCAKey,
				Duration:     certDuration,
				ClusterName:  clusterName,
			})
			if err != nil {
				return fmt.Errorf("handler: %w", err)
			}

			// ── TLS ───────────────────────────────────────────────────────────────────
			// Load the key pair once so misconfiguration fails at startup. It
			// is then served from a cache that is refreshed only when the
			// files on disk change (e.g. after a Let's Encrypt renewal),
			// instead of being re-parsed on every handshake.
			certs, err := newCertReloader(tlsCert, tlsKey)
			if err != nil {
				return fmt.Errorf("TLS: loading certificate: %w", err)
			}
			tlsConfig := &tls.Config{
				MinVersion:     tls.VersionTLS12,
				GetCertificate: certs.GetCertificate,
			}
			ln, err := tls.Listen("tcp", addr, tlsConfig)
			if err != nil {
				return fmt.Errorf("listen %s: %w", addr, err)
			}

			mux := http.NewServeMux()
			mux.HandleFunc("/kubeconfig", h.ServeHTTP)
			mux.HandleFunc("/credential", h.ServeCredential)
			mux.HandleFunc("/bootstrap", h.ServeBootstrap)
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				// "/" is a catch-all pattern; only the exact root gets the index.
				if r.URL.Path != "/" {
					http.NotFound(w, r)
					return
				}
				w.Header().Set("Content-Type", "text/plain; charset=utf-8")
				// Constant text: write it directly rather than treating it
				// as a format string. A write error only means the client
				// went away, so it is deliberately ignored.
				_, _ = io.WriteString(w, "ward — Kubernetes credential gateway\n\n"+
					" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
					" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
					" GET /kubeconfig kubeconfig with embedded client certificate\n")
			})
			srv := &http.Server{
				Handler:      mux,
				ReadTimeout:  30 * time.Second,
				WriteTimeout: 30 * time.Second,
			}
			log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
			// Serve always returns a non-nil error and closes ln.
			return srv.Serve(ln)
		},
	}
	cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
	cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
	cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
	cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
	cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
	cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
	cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
	cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
	cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
	cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
	cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
	cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
	cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
	cmd.Flags().StringVar(&ldapBindPassword, "ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)")
	cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
	cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
	cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
	cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
	cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
	return cmd
}

// certReloader serves a TLS key pair loaded from certPath/keyPath. The
// parsed certificate is cached and the files are re-read only when either
// file's modification time changes. Two stat calls per handshake are far
// cheaper than re-parsing PEM and the private key. os.Stat follows
// symlinks, so a symlink switched to a renewed file is noticed.
type certReloader struct {
	certPath string
	keyPath  string

	mu      sync.Mutex // guards the fields below
	cert    *tls.Certificate
	certMod time.Time
	keyMod  time.Time
	retryAt time.Time // earliest time a failed reload may be retried
}

// newCertReloader loads the key pair once. It returns an error if the files
// cannot be stat-ed or parsed, so misconfiguration surfaces at startup.
func newCertReloader(certPath, keyPath string) (*certReloader, error) {
	r := &certReloader{certPath: certPath, keyPath: keyPath}
	// Stat before loading: if the files change in between, the newer mtime
	// is seen on the next handshake and the pair is reloaded.
	certMod, keyMod, err := r.modTimes()
	if err != nil {
		return nil, err
	}
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, err
	}
	r.cert, r.certMod, r.keyMod = &cert, certMod, keyMod
	return r, nil
}

// modTimes returns the modification times of the certificate and key files.
func (r *certReloader) modTimes() (time.Time, time.Time, error) {
	ci, err := os.Stat(r.certPath)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	ki, err := os.Stat(r.keyPath)
	if err != nil {
		return time.Time{}, time.Time{}, err
	}
	return ci.ModTime(), ki.ModTime(), nil
}

// GetCertificate implements tls.Config.GetCertificate. It returns the cached
// certificate, reloading it first if the files on disk have changed. If the
// files cannot be stat-ed or the new pair fails to load (e.g. certificate
// and key swapped at slightly different moments during a renewal), the
// previous certificate keeps being served. A failed reload is retried at
// most every retryDelay so a broken pair is not re-parsed on every handshake.
func (r *certReloader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	const retryDelay = 10 * time.Second

	r.mu.Lock()
	defer r.mu.Unlock()

	certMod, keyMod, err := r.modTimes()
	if err != nil {
		dbg.Printf("TLS: stat certificate files: %v (serving cached certificate)", err)
		return r.cert, nil
	}
	if certMod.Equal(r.certMod) && keyMod.Equal(r.keyMod) {
		return r.cert, nil
	}
	if time.Now().Before(r.retryAt) {
		return r.cert, nil
	}
	cert, err := tls.LoadX509KeyPair(r.certPath, r.keyPath)
	if err != nil {
		r.retryAt = time.Now().Add(retryDelay)
		log.Printf("TLS: reloading certificate failed (serving previous one): %v", err)
		return r.cert, nil
	}
	r.cert, r.certMod, r.keyMod = &cert, certMod, keyMod
	log.Printf("TLS: reloaded certificate from %s", r.certPath)
	return r.cert, nil
}
// newCredentialCmd builds the "credential" subcommand, a kubectl exec
// credential plugin. It returns a cached ExecCredential for --server when
// one is available (unless --no-cache is set). Otherwise it fetches a fresh
// one via credFetch and stores it back in the cache. Either way the result
// is printed with credPrint. Debug messages go to stderr when --debug or
// $WARD_DEBUG is set.
func newCredentialCmd() *cobra.Command {
	var server, username string
	var noKerberos, noCache, debugFlag bool

	cmd := &cobra.Command{
		Use:   "credential",
		Short: "kubectl exec credential plugin — fetch an ExecCredential from ward",
		Long: `Acts as a kubectl exec credential plugin. Fetches an ExecCredential JSON
from the ward server and prints it to stdout for kubectl to consume.
Authentication priority:
1. Kerberos SPNEGO using the active credential cache (from kinit)
2. Basic auth — prompts for password, or reads $WARD_PASSWORD
Credentials are cached in ~/.cache/ward/ and reused until 5 minutes
before expiry, so kubectl invocations are fast after the first call.
Debug output goes to stderr (kubectl surfaces this to the terminal):
WARD_DEBUG=1 kubectl get nodes`,
	}

	cmd.RunE = func(cmd *cobra.Command, args []string) error {
		if server == "" {
			return fmt.Errorf("--server is required")
		}
		// debugf is handed to the cache and fetch helpers. It is a no-op
		// unless debug output was requested.
		debugf := func(format string, a ...interface{}) {
			if !debugFlag {
				return
			}
			fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
		}

		useCache := !noCache
		if useCache {
			if cached, hit := credReadCache(server, debugf); hit {
				return credPrint(cached)
			}
		}

		fresh, err := credFetch(server, username, noKerberos, debugf)
		if err != nil {
			return err
		}
		if useCache {
			credWriteCache(server, fresh, debugf)
		}
		return credPrint(fresh)
	}

	flags := cmd.Flags()
	flags.StringVar(&server, "server", "", "ward server URL (required)")
	flags.StringVar(&username, "username", "", "username for Basic auth fallback (default: $USER)")
	flags.BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; always use Basic auth")
	flags.BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential")
	flags.BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
	return cmd
}
// detectFQDN returns the fully-qualified domain name of the local host.
//
// If os.Hostname already contains a dot, it is returned as-is. Otherwise the
// short hostname is resolved forward and its first address is resolved back.
// Among the reverse-lookup names, one of the form "<hostname>.<domain>" is
// preferred, then any other dotted name. Undotted names such as "localhost"
// (common when the hostname shares an /etc/hosts line with 127.0.0.1) are
// never returned in place of the hostname. Any lookup failure falls back to
// the short hostname; "localhost" is returned only when the hostname itself
// cannot be read.
func detectFQDN() string {
	hostname, err := os.Hostname()
	if err != nil {
		return "localhost"
	}
	if strings.Contains(hostname, ".") {
		return hostname
	}
	addrs, err := net.LookupHost(hostname)
	if err != nil || len(addrs) == 0 {
		return hostname
	}
	// Only the first address is reverse-resolved: each lookup may block on
	// DNS, and this runs every time the command tree is built.
	names, err := net.LookupAddr(addrs[0])
	if err != nil {
		return hostname
	}
	prefix := strings.ToLower(hostname) + "."
	fallback := ""
	for _, name := range names {
		name = strings.TrimSuffix(name, ".")
		if !strings.Contains(name, ".") {
			continue // e.g. "localhost" or the bare hostname: no domain to offer
		}
		if strings.HasPrefix(strings.ToLower(name), prefix) {
			return name
		}
		if fallback == "" {
			fallback = name
		}
	}
	if fallback != "" {
		return fallback
	}
	return hostname
}
// domainPart drops the leading label of a FQDN and returns everything after
// the first dot, e.g. "host.example.com" → "example.com". A name without a
// dot is returned unchanged.
func domainPart(fqdn string) string {
	parts := strings.SplitN(fqdn, ".", 2)
	if len(parts) < 2 {
		return fqdn
	}
	return parts[1]
}
// firstLabel returns the part of a domain name before its first dot, e.g.
// "example.com" → "example". A name without a dot is returned unchanged.
func firstLabel(domain string) string {
	return strings.SplitN(domain, ".", 2)[0]
}