253 lines
9.4 KiB
Go
253 lines
9.4 KiB
Go
package cmd
|
|
|
|
import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/spf13/cobra"

	"git.shee.sh/james/ward/pkg/auth"
	"git.shee.sh/james/ward/pkg/cert"
	"git.shee.sh/james/ward/pkg/handler"
)
|
|
|
|
func newServeCmd() *cobra.Command {
|
|
fqdn := detectFQDN()
|
|
domain := domainPart(fqdn)
|
|
realm := strings.ToUpper(domain)
|
|
|
|
var (
|
|
addr string
|
|
tlsCert string
|
|
tlsKey string
|
|
k3sServer string
|
|
serverCACert string
|
|
clientCACert string
|
|
clientCAKey string
|
|
certDuration time.Duration
|
|
clusterName string
|
|
|
|
ldapDomain string
|
|
ldapOn bool
|
|
ldapURI string
|
|
ldapBindDN string
|
|
|
|
loginCertDuration time.Duration
|
|
|
|
kerberosOn bool
|
|
keytabPath string
|
|
spn string
|
|
|
|
htpasswdPath string
|
|
debugFlag bool
|
|
)
|
|
|
|
cmd := &cobra.Command{
|
|
Use: "serve",
|
|
Short: "Run the ward authentication server",
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
dbg := log.New(io.Discard, "[ward] ", 0)
|
|
if debugFlag {
|
|
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
|
|
dbg.Print("debug logging enabled")
|
|
}
|
|
|
|
if k3sServer == "" {
|
|
return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
|
|
}
|
|
|
|
// ── LDAP ──────────────────────────────────────────────────────────────────
|
|
ldapBindPassword := os.Getenv("WARD_LDAP_BIND_PASSWORD")
|
|
var ldapAuth *auth.LDAPAuth
|
|
if ldapURI != "" || ldapOn {
|
|
var la *auth.LDAPAuth
|
|
var err error
|
|
if ldapURI != "" {
|
|
la, err = auth.NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword, dbg)
|
|
} else {
|
|
la, err = auth.NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword, dbg)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("LDAP: %w", err)
|
|
}
|
|
ldapAuth = la
|
|
anon := ldapBindDN == ""
|
|
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.Host(), la.Port(), la.UseTLS(), ldapDomain, anon)
|
|
}
|
|
|
|
if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
|
|
return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
|
|
}
|
|
|
|
// ── Kerberos ──────────────────────────────────────────────────────────────
|
|
var krbAuth *auth.KerberosAuth
|
|
if kerberosOn {
|
|
ka, err := auth.NewKerberosAuth(keytabPath, spn)
|
|
if err != nil {
|
|
return fmt.Errorf("Kerberos: %w", err)
|
|
}
|
|
krbAuth = ka
|
|
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm)
|
|
}
|
|
|
|
// ── htpasswd ──────────────────────────────────────────────────────────────
|
|
var htpasswdAuth *auth.HtpasswdAuth
|
|
if htpasswdPath != "" {
|
|
ha, err := auth.NewHtpasswdAuth(htpasswdPath)
|
|
if err != nil {
|
|
return fmt.Errorf("htpasswd: %w", err)
|
|
}
|
|
htpasswdAuth = ha
|
|
log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
|
|
|
|
sighup := make(chan os.Signal, 1)
|
|
signal.Notify(sighup, syscall.SIGHUP)
|
|
go func() {
|
|
for range sighup {
|
|
if err := htpasswdAuth.Reload(); err != nil {
|
|
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
|
|
} else {
|
|
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// ── Handler ───────────────────────────────────────────────────────────────
|
|
h, err := handler.NewHandler(ldapAuth, krbAuth, htpasswdAuth, &cert.CertConfig{
|
|
ServerURL: k3sServer,
|
|
ServerCACert: serverCACert,
|
|
ClientCACert: clientCACert,
|
|
ClientCAKey: clientCAKey,
|
|
Duration: certDuration,
|
|
LoginDuration: loginCertDuration,
|
|
ClusterName: clusterName,
|
|
}, dbg)
|
|
if err != nil {
|
|
return fmt.Errorf("handler: %w", err)
|
|
}
|
|
|
|
// ── TLS ───────────────────────────────────────────────────────────────────
|
|
if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil {
|
|
return fmt.Errorf("TLS: loading certificate: %w", err)
|
|
}
|
|
tlsConfig := &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
c, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reloading TLS cert: %w", err)
|
|
}
|
|
return &c, nil
|
|
},
|
|
}
|
|
|
|
ln, err := tls.Listen("tcp", addr, tlsConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("listen %s: %w", addr, err)
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
|
|
mux.HandleFunc("/credential", h.ServeCredential)
|
|
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
|
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
|
|
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
|
|
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
|
|
" GET /kubeconfig kubeconfig with embedded client certificate\n")
|
|
})
|
|
|
|
srv := &http.Server{
|
|
Handler: mux,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
}
|
|
|
|
log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
|
|
return srv.Serve(ln)
|
|
},
|
|
}
|
|
|
|
cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
|
|
cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
|
|
cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
|
|
cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
|
|
cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
|
|
cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
|
|
cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
|
|
cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
|
|
cmd.Flags().DurationVar(&loginCertDuration, "login-cert-duration", 168*time.Hour, "Validity period of certificates issued by 'ward login' (default 7 days)")
|
|
cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
|
|
|
|
cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
|
|
cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
|
|
cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
|
|
cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
|
|
// LDAP bind password is read exclusively from $WARD_LDAP_BIND_PASSWORD to avoid
|
|
// exposure in process listings.
|
|
|
|
cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
|
|
cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
|
|
cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
|
|
|
|
cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
|
|
|
|
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
|
|
|
|
return cmd
|
|
}
|
|
|
|
// detectFQDN returns the fully-qualified domain name of the local host.
//
// If os.Hostname already contains a dot, it is returned unchanged. Otherwise:
//   - the hostname is resolved forward with net.LookupHost;
//   - each resulting address is resolved back with net.LookupAddr;
//   - the first reverse name that contains a dot and is not a localhost alias
//     is returned, without its trailing root dot.
//
// If no such name is found, or DNS fails, the short hostname is returned.
// "localhost" is returned only when os.Hostname itself fails.
func detectFQDN() string {
	hostname, err := os.Hostname()
	if err != nil {
		return "localhost"
	}
	if strings.Contains(hostname, ".") {
		return hostname
	}

	addrs, err := net.LookupHost(hostname)
	if err != nil {
		return hostname
	}
	for _, addr := range addrs {
		names, err := net.LookupAddr(addr)
		if err != nil {
			continue
		}
		for _, name := range names {
			name = strings.TrimSuffix(name, ".")
			// Reject dotless names (e.g. "localhost" when the hostname maps
			// to a loopback address) and loopback aliases such as
			// "localhost.localdomain". The domain and Kerberos realm derived
			// from them would be meaningless, so the documented short-hostname
			// fallback is preferable.
			if !strings.Contains(name, ".") || strings.HasPrefix(strings.ToLower(name), "localhost.") {
				continue
			}
			return name
		}
	}
	return hostname
}
|
|
|
|
// domainPart returns everything after the first '.' in fqdn, i.e. the FQDN
// with its leading host label removed. A name with no dot is returned as is.
//
//	"host.example.com" → "example.com"
//	"host"             → "host"
func domainPart(fqdn string) string {
	labels := strings.SplitN(fqdn, ".", 2)
	if len(labels) < 2 {
		// No separator: nothing to strip.
		return fqdn
	}
	return labels[1]
}
|
|
|
|
// firstLabel returns the portion of domain before its first '.'. A name
// without any dot is returned unchanged.
//
//	"example.com" → "example"
//	"example"     → "example"
func firstLabel(domain string) string {
	// SplitN with n=2 always yields at least one element, and the first
	// element is the text before the first separator (or the whole input).
	return strings.SplitN(domain, ".", 2)[0]
}
|