2 Commits
v0.4.2 ... main

Author SHA1 Message Date
a0a7932f99 Refactor project layout
All checks were successful
Release / release (push) Successful in 1m38s
2026-04-01 13:16:06 +02:00
5d2a80bd30 Attempt to support unconfigured kerberos and weird Mac cred cache 2026-03-31 12:43:36 +02:00
14 changed files with 834 additions and 459 deletions

View File

@@ -1,4 +1,4 @@
package main package cmd
import ( import (
"cmp" "cmp"
@@ -8,69 +8,63 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"os/signal"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"syscall"
"time" "time"
"golang.org/x/term"
krb5client "github.com/jcmturner/gokrb5/v8/client" krb5client "github.com/jcmturner/gokrb5/v8/client"
"github.com/jcmturner/gokrb5/v8/config" "github.com/jcmturner/gokrb5/v8/config"
"github.com/jcmturner/gokrb5/v8/credentials" "github.com/jcmturner/gokrb5/v8/credentials"
"github.com/jcmturner/gokrb5/v8/spnego" "github.com/jcmturner/gokrb5/v8/spnego"
"git.shee.sh/james/ward/pkg/handler"
) )
func credFetch(server, username string, noKerberos bool, logf func(string, ...any)) (*ExecCredential, error) { func credFetch(server string, noKerberos bool, logf func(string, ...any)) (*handler.ExecCredential, error) {
url := strings.TrimRight(server, "/") + "/credential" url := strings.TrimRight(server, "/") + "/credential"
// ── Kerberos SPNEGO ─────────────────────────────────────────────────────── // ── Kerberos SPNEGO ───────────────────────────────────────────────────────
if !noKerberos { if !noKerberos {
body, err := credFetchKerberos(url, logf) body, err := credFetchKerberos(url, logf)
if err != nil { if err != nil {
logf("Kerberos: %v — falling back to Basic auth", err) logf("Kerberos: %v", err)
fmt.Fprintf(os.Stderr, "ward: Kerberos failed (%v); falling back to Basic auth\nhint: run 'kinit' to avoid the password prompt\n", err)
} else { } else {
return credParse(body) return credParse(body)
} }
} }
// ── Basic auth ──────────────────────────────────────────────────────────── return nil, fmt.Errorf("no valid credential — run 'ward login --server %s'", server)
if username == "" {
username = os.Getenv("USER")
}
password := os.Getenv("WARD_PASSWORD")
if password == "" {
var err error
password, err = credPromptPassword(username)
if err != nil {
return nil, err
}
} else {
logf("Basic: using password from $WARD_PASSWORD")
}
body, err := credFetchBasic(url, username, password, logf)
if err != nil {
return nil, err
}
return credParse(body)
} }
func credFetchKerberos(url string, logf func(string, ...any)) ([]byte, error) { func credFetchKerberos(url string, logf func(string, ...any)) ([]byte, error) {
krb5cfgPath := krb5ConfigPath() krb5cfgPath := krb5ConfigPath()
logf("Kerberos: loading config from %s", krb5cfgPath) logf("Kerberos: loading config from %s", krb5cfgPath)
krb5cfg, err := config.Load(krb5cfgPath) var krb5cfg *config.Config
if err != nil { if _, statErr := os.Stat(krb5cfgPath); os.IsNotExist(statErr) && os.Getenv("KRB5_CONFIG") == "" {
return nil, fmt.Errorf("loading krb5 config: %w", err) logf("Kerberos: %s not found, using default config (KDC discovery via DNS)", krb5cfgPath)
krb5cfg = config.New()
krb5cfg.LibDefaults.DNSLookupKDC = true
} else {
var err error
krb5cfg, err = config.Load(krb5cfgPath)
if err != nil {
return nil, fmt.Errorf("loading krb5 config: %w", err)
}
} }
ccPath := ccachePath() ccPath, err := ccachePath()
if err != nil {
return nil, err
}
logf("Kerberos: loading credential cache from %s", ccPath) logf("Kerberos: loading credential cache from %s", ccPath)
ccache, err := credentials.LoadCCache(ccPath) ccache, err := credentials.LoadCCache(ccPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("no credential cache (%w) — run 'kinit'", err) hint := "run 'kinit'"
if runtime.GOOS == "darwin" && os.Getenv("KRB5CCNAME") == "" {
hint = fmt.Sprintf("on macOS, run: kinit -c /tmp/krb5cc_%d", os.Getuid())
}
return nil, fmt.Errorf("no credential cache (%w) — %s", err, hint)
} }
cl, err := krb5client.NewFromCCache(ccache, krb5cfg, krb5client.DisablePAFXFAST(true)) cl, err := krb5client.NewFromCCache(ccache, krb5cfg, krb5client.DisablePAFXFAST(true))
@@ -133,8 +127,8 @@ func credFetchBasic(url, username, password string, logf func(string, ...any)) (
return body, nil return body, nil
} }
func credParse(body []byte) (*ExecCredential, error) { func credParse(body []byte) (*handler.ExecCredential, error) {
var ec ExecCredential var ec handler.ExecCredential
if err := json.Unmarshal(body, &ec); err != nil { if err := json.Unmarshal(body, &ec); err != nil {
return nil, fmt.Errorf("parsing server response: %w", err) return nil, fmt.Errorf("parsing server response: %w", err)
} }
@@ -144,7 +138,7 @@ func credParse(body []byte) (*ExecCredential, error) {
return &ec, nil return &ec, nil
} }
func credPrint(ec *ExecCredential) error { func credPrint(ec *handler.ExecCredential) error {
return json.NewEncoder(os.Stdout).Encode(ec) return json.NewEncoder(os.Stdout).Encode(ec)
} }
@@ -166,7 +160,7 @@ func credCacheFile(serverURL string) string {
return filepath.Join(credCacheDir(), fmt.Sprintf("%x.json", h)) return filepath.Join(credCacheDir(), fmt.Sprintf("%x.json", h))
} }
func credReadCache(serverURL string, logf func(string, ...any)) (*ExecCredential, bool) { func credReadCache(serverURL string, logf func(string, ...any)) (*handler.ExecCredential, bool) {
path := credCacheFile(serverURL) path := credCacheFile(serverURL)
logf("cache: checking %s", path) logf("cache: checking %s", path)
@@ -175,7 +169,7 @@ func credReadCache(serverURL string, logf func(string, ...any)) (*ExecCredential
logf("cache: miss (%v)", err) logf("cache: miss (%v)", err)
return nil, false return nil, false
} }
var ec ExecCredential var ec handler.ExecCredential
if err := json.Unmarshal(data, &ec); err != nil { if err := json.Unmarshal(data, &ec); err != nil {
logf("cache: corrupt, ignoring (%v)", err) logf("cache: corrupt, ignoring (%v)", err)
return nil, false return nil, false
@@ -199,7 +193,7 @@ func credReadCache(serverURL string, logf func(string, ...any)) (*ExecCredential
return &ec, true return &ec, true
} }
func credWriteCache(serverURL string, ec *ExecCredential, logf func(string, ...any)) { func credWriteCache(serverURL string, ec *handler.ExecCredential, logf func(string, ...any)) {
dir := credCacheDir() dir := credCacheDir()
if err := os.MkdirAll(dir, 0700); err != nil { if err := os.MkdirAll(dir, 0700); err != nil {
logf("cache: failed to create dir: %v", err) logf("cache: failed to create dir: %v", err)
@@ -224,47 +218,23 @@ func krb5ConfigPath() string {
return cmp.Or(os.Getenv("KRB5_CONFIG"), "/etc/krb5.conf") return cmp.Or(os.Getenv("KRB5_CONFIG"), "/etc/krb5.conf")
} }
// ccachePath returns the path to the active Kerberos credential cache. // ccachePath returns the path to the active Kerberos credential cache, or an
// Respects $KRB5CCNAME; strips the "FILE:" prefix if present. // error if $KRB5CCNAME names a non-file cache type that gokrb5 cannot read.
// Non-file ccache types (API:, KEYRING:, DIR:) are not supported by gokrb5 //
// and will produce an error when LoadCCache is called. // On macOS, kinit defaults to API: caches. Work around it with:
func ccachePath() string { //
// kinit -c /tmp/krb5cc_$(id -u)
func ccachePath() (string, error) {
if v := os.Getenv("KRB5CCNAME"); v != "" { if v := os.Getenv("KRB5CCNAME"); v != "" {
return strings.TrimPrefix(v, "FILE:") for _, prefix := range []string{"API:", "KEYRING:", "DIR:", "KCM:"} {
if strings.HasPrefix(v, prefix) {
return "", fmt.Errorf(
"credential cache type %s is not supported (gokrb5 requires a file-based cache)\n"+
"hint: re-run kinit with: kinit -c /tmp/krb5cc_%d",
prefix, os.Getuid())
}
}
return strings.TrimPrefix(v, "FILE:"), nil
} }
return fmt.Sprintf("/tmp/krb5cc_%d", os.Getuid()) return fmt.Sprintf("/tmp/krb5cc_%d", os.Getuid()), nil
}
// ── Password prompt ────────────────────────────────────────────────────────────
func credPromptPassword(username string) (string, error) {
terminal, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
if err != nil {
return "", fmt.Errorf(
"cannot open terminal and $WARD_PASSWORD is not set\n" +
"hint: run 'kinit' for Kerberos auth, or set $WARD_PASSWORD for non-interactive use")
}
oldState, err := term.MakeRaw(int(terminal.Fd()))
if err != nil {
return "", fmt.Errorf("setting terminal raw mode: %w", err)
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
term.Restore(int(terminal.Fd()), oldState)
os.Exit(1)
}()
fmt.Fprintf(terminal, "Password for %s: ", username)
pw, err := term.ReadPassword(int(terminal.Fd()))
fmt.Fprintf(terminal, "\r\n") // newline after the hidden input
signal.Stop(sigCh)
term.Restore(int(terminal.Fd()), oldState)
if err != nil {
return "", fmt.Errorf("reading password: %w", err)
}
return string(pw), nil
} }

68
cmd/credential.go Normal file
View File

@@ -0,0 +1,68 @@
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
)
func newCredentialCmd() *cobra.Command {
var (
server string
noKerberos bool
noCache bool
debugFlag bool
)
cmd := &cobra.Command{
Use: "credential",
Short: "kubectl exec credential plugin — serve a cached ExecCredential to kubectl",
Long: `Acts as a kubectl exec credential plugin. Returns a cached ExecCredential
JSON to kubectl. On a cache miss, silently attempts Kerberos SPNEGO; if that
also fails, exits with an error directing the user to run 'ward login'.
Run 'ward login' once to authenticate and populate the cache. After that,
kubectl works silently until the credential expires.
Debug output goes to stderr (kubectl surfaces this to the terminal):
WARD_DEBUG=1 kubectl get nodes`,
RunE: func(cmd *cobra.Command, args []string) error {
if server == "" {
return fmt.Errorf("--server is required")
}
server = normalizeServer(server)
logf := func(format string, a ...any) {
if debugFlag {
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
}
}
if !noCache {
if ec, ok := credReadCache(server, logf); ok {
return credPrint(ec)
}
}
ec, err := credFetch(server, noKerberos, logf)
if err != nil {
return err
}
if !noCache {
credWriteCache(server, ec, logf)
}
return credPrint(ec)
},
}
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos SPNEGO")
cmd.Flags().BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
return cmd
}

139
cmd/login.go Normal file
View File

@@ -0,0 +1,139 @@
package cmd
import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"git.shee.sh/james/ward/pkg/handler"
"git.shee.sh/james/ward/pkg/kubeconfig"
)
func newLoginCmd() *cobra.Command {
var (
server string
username string
contextName string
noKerberos bool
noSetContext bool
debugFlag bool
)
cmd := &cobra.Command{
Use: "login",
Short: "Authenticate to ward and configure kubectl",
Long: `Authenticates to the ward server, caches a long-lived credential, and
updates ~/.kube/config with the cluster, user (exec plugin), and context.
After login, kubectl works silently until the credential expires, at
which point you run 'ward login' again.
Authentication priority:
1. Kerberos SPNEGO using the active credential cache (from kinit)
2. Password prompt`,
RunE: func(cmd *cobra.Command, args []string) error {
if server == "" {
return fmt.Errorf("--server is required")
}
server = normalizeServer(server)
logf := func(format string, a ...any) {
if debugFlag {
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
}
}
if username == "" {
username = os.Getenv("USER")
}
ec, err := loginFetch(server, username, noKerberos, logf)
if err != nil {
return err
}
credWriteCache(server, ec, logf)
bootstrap, err := fetchBootstrap(server, username, logf)
if err != nil {
return fmt.Errorf("fetching bootstrap kubeconfig: %w", err)
}
if contextName != "" {
kubeconfig.RenameContext(bootstrap, contextName)
}
contextApplied := bootstrap.CurrentContext
if err := kubeconfig.Merge(bootstrap, !noSetContext); err != nil {
return fmt.Errorf("updating kubeconfig: %w", err)
}
fmt.Fprintf(os.Stderr, "ward: logged in — context %q configured in %s\n",
contextApplied, kubeconfig.FilePath())
return nil
},
}
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
cmd.Flags().StringVar(&username, "username", "", "username (default: $USER)")
cmd.Flags().StringVar(&contextName, "context", "", "kubectl context/cluster name (overrides server default)")
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; use password auth")
cmd.Flags().BoolVar(&noSetContext, "no-set-context", false, "do not set as current-context")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr")
return cmd
}
func loginFetch(server, username string, noKerberos bool, logf func(string, ...any)) (*handler.ExecCredential, error) {
loginURL := strings.TrimRight(server, "/") + "/credential?login=true"
if !noKerberos {
body, err := credFetchKerberos(loginURL, logf)
if err != nil {
logf("Kerberos: %v — falling back to password auth", err)
fmt.Fprintf(os.Stderr, "ward: Kerberos failed (%v); using password auth\nhint: run 'kinit' to avoid the password prompt next time\n", err)
} else {
return credParse(body)
}
}
password, err := promptPassword(username)
if err != nil {
return nil, err
}
body, err := credFetchBasic(loginURL, username, password, logf)
if err != nil {
return nil, err
}
return credParse(body)
}
func fetchBootstrap(server, username string, logf func(string, ...any)) (*kubeconfig.KubeConfig, error) {
bootstrapURL := strings.TrimRight(server, "/") + "/bootstrap?user=" + url.QueryEscape(username)
logf("HTTP: GET %s", bootstrapURL)
resp, err := http.Get(bootstrapURL) //nolint:gosec // URL derived from user-supplied server flag
if err != nil {
return nil, fmt.Errorf("HTTP request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned %s", resp.Status)
}
var cfg kubeconfig.KubeConfig
if err := yaml.Unmarshal(body, &cfg); err != nil {
return nil, fmt.Errorf("parsing bootstrap kubeconfig: %w", err)
}
return &cfg, nil
}

46
cmd/root.go Normal file
View File

@@ -0,0 +1,46 @@
package cmd
import (
"fmt"
"os"
"strings"
"github.com/spf13/cobra"
)
// Execute builds the root command and runs it.
func Execute() {
root := newRootCmd()
if err := root.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func newRootCmd() *cobra.Command {
root := &cobra.Command{
Use: "ward",
Short: "Kubernetes credential gateway",
SilenceUsage: true,
SilenceErrors: true,
}
root.AddCommand(newServeCmd())
root.AddCommand(newCredentialCmd())
root.AddCommand(newLoginCmd())
return root
}
// normalizeServer ensures server has an https:// scheme and a port.
// Shared by the credential and login commands.
func normalizeServer(server string) string {
if !strings.Contains(server, "://") {
server = "https://" + server
}
parts := strings.Split(server, ":")
if len(parts) == 2 {
server += ":8443"
}
return server
}

252
cmd/serve.go Normal file
View File

@@ -0,0 +1,252 @@
package cmd
import (
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"git.shee.sh/james/ward/pkg/auth"
"git.shee.sh/james/ward/pkg/cert"
"git.shee.sh/james/ward/pkg/handler"
)
func newServeCmd() *cobra.Command {
fqdn := detectFQDN()
domain := domainPart(fqdn)
realm := strings.ToUpper(domain)
var (
addr string
tlsCert string
tlsKey string
k3sServer string
serverCACert string
clientCACert string
clientCAKey string
certDuration time.Duration
clusterName string
ldapDomain string
ldapOn bool
ldapURI string
ldapBindDN string
loginCertDuration time.Duration
kerberosOn bool
keytabPath string
spn string
htpasswdPath string
debugFlag bool
)
cmd := &cobra.Command{
Use: "serve",
Short: "Run the ward authentication server",
RunE: func(cmd *cobra.Command, args []string) error {
dbg := log.New(io.Discard, "[ward] ", 0)
if debugFlag {
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
dbg.Print("debug logging enabled")
}
if k3sServer == "" {
return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
}
// ── LDAP ──────────────────────────────────────────────────────────────────
ldapBindPassword := os.Getenv("WARD_LDAP_BIND_PASSWORD")
var ldapAuth *auth.LDAPAuth
if ldapURI != "" || ldapOn {
var la *auth.LDAPAuth
var err error
if ldapURI != "" {
la, err = auth.NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword, dbg)
} else {
la, err = auth.NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword, dbg)
}
if err != nil {
return fmt.Errorf("LDAP: %w", err)
}
ldapAuth = la
anon := ldapBindDN == ""
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.Host(), la.Port(), la.UseTLS(), ldapDomain, anon)
}
if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
}
// ── Kerberos ──────────────────────────────────────────────────────────────
var krbAuth *auth.KerberosAuth
if kerberosOn {
ka, err := auth.NewKerberosAuth(keytabPath, spn)
if err != nil {
return fmt.Errorf("Kerberos: %w", err)
}
krbAuth = ka
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm)
}
// ── htpasswd ──────────────────────────────────────────────────────────────
var htpasswdAuth *auth.HtpasswdAuth
if htpasswdPath != "" {
ha, err := auth.NewHtpasswdAuth(htpasswdPath)
if err != nil {
return fmt.Errorf("htpasswd: %w", err)
}
htpasswdAuth = ha
log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
sighup := make(chan os.Signal, 1)
signal.Notify(sighup, syscall.SIGHUP)
go func() {
for range sighup {
if err := htpasswdAuth.Reload(); err != nil {
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
} else {
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
}
}
}()
}
// ── Handler ───────────────────────────────────────────────────────────────
h, err := handler.NewHandler(ldapAuth, krbAuth, htpasswdAuth, &cert.CertConfig{
ServerURL: k3sServer,
ServerCACert: serverCACert,
ClientCACert: clientCACert,
ClientCAKey: clientCAKey,
Duration: certDuration,
LoginDuration: loginCertDuration,
ClusterName: clusterName,
}, dbg)
if err != nil {
return fmt.Errorf("handler: %w", err)
}
// ── TLS ───────────────────────────────────────────────────────────────────
if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil {
return fmt.Errorf("TLS: loading certificate: %w", err)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
c, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
if err != nil {
return nil, fmt.Errorf("reloading TLS cert: %w", err)
}
return &c, nil
},
}
ln, err := tls.Listen("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("listen %s: %w", addr, err)
}
mux := http.NewServeMux()
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
mux.HandleFunc("/credential", h.ServeCredential)
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
" GET /kubeconfig kubeconfig with embedded client certificate\n")
})
srv := &http.Server{
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
return srv.Serve(ln)
},
}
cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
cmd.Flags().DurationVar(&loginCertDuration, "login-cert-duration", 168*time.Hour, "Validity period of certificates issued by 'ward login' (default 7 days)")
cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
// LDAP bind password is read exclusively from $WARD_LDAP_BIND_PASSWORD to avoid
// exposure in process listings.
cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
return cmd
}
// detectFQDN returns the fully-qualified domain name of the local host,
// falling back to the short hostname if DNS resolution fails.
func detectFQDN() string {
hostname, err := os.Hostname()
if err != nil {
return "localhost"
}
if strings.Contains(hostname, ".") {
return hostname
}
addrs, err := net.LookupHost(hostname)
if err != nil || len(addrs) == 0 {
return hostname
}
names, err := net.LookupAddr(addrs[0])
if err != nil || len(names) == 0 {
return hostname
}
return strings.TrimSuffix(names[0], ".")
}
// domainPart strips the first label from a FQDN.
// "host.example.com" → "example.com"
func domainPart(fqdn string) string {
if idx := strings.IndexByte(fqdn, '.'); idx >= 0 {
return fqdn[idx+1:]
}
return fqdn
}
// firstLabel returns the first dot-separated label of a domain name.
// "example.com" → "example"
func firstLabel(domain string) string {
if idx := strings.IndexByte(domain, '.'); idx >= 0 {
return domain[:idx]
}
return domain
}

40
cmd/tty.go Normal file
View File

@@ -0,0 +1,40 @@
package cmd
import (
"fmt"
"os"
"os/signal"
"syscall"
"golang.org/x/term"
)
func promptPassword(username string) (string, error) {
terminal, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
if err != nil {
return "", fmt.Errorf("cannot open terminal for password prompt\nhint: run 'kinit' for Kerberos auth")
}
oldState, err := term.MakeRaw(int(terminal.Fd()))
if err != nil {
return "", fmt.Errorf("setting terminal raw mode: %w", err)
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
term.Restore(int(terminal.Fd()), oldState)
os.Exit(1)
}()
fmt.Fprintf(terminal, "Password for %s: ", username)
pw, err := term.ReadPassword(int(terminal.Fd()))
fmt.Fprintf(terminal, "\r\n")
signal.Stop(sigCh)
term.Restore(int(terminal.Fd()), oldState)
if err != nil {
return "", fmt.Errorf("reading password: %w", err)
}
return string(pw), nil
}

1
go.mod
View File

@@ -24,4 +24,5 @@ require (
github.com/spf13/pflag v1.0.9 // indirect github.com/spf13/pflag v1.0.9 // indirect
golang.org/x/net v0.22.0 // indirect golang.org/x/net v0.22.0 // indirect
golang.org/x/sys v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

334
main.go
View File

@@ -1,337 +1,7 @@
package main package main
import ( import "git.shee.sh/james/ward/cmd"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
)
// dbg is a package-level debug logger, active only when --debug is passed.
// Defaults to discarding output; serve command points it at stderr when --debug is set.
var dbg = log.New(io.Discard, "[ward] ", 0)
func main() { func main() {
root := &cobra.Command{ cmd.Execute()
Use: "ward",
Short: "Kubernetes credential gateway",
SilenceUsage: true,
SilenceErrors: true,
}
root.AddCommand(newServeCmd())
root.AddCommand(newCredentialCmd())
if err := root.Execute(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func newServeCmd() *cobra.Command {
fqdn := detectFQDN()
domain := domainPart(fqdn)
realm := strings.ToUpper(domain)
var (
addr string
tlsCert string
tlsKey string
k3sServer string
serverCACert string
clientCACert string
clientCAKey string
certDuration time.Duration
clusterName string
ldapDomain string
ldapOn bool
ldapURI string
ldapBindDN string
ldapBindPassword string
kerberosOn bool
keytabPath string
spn string
htpasswdPath string
debugFlag bool
)
cmd := &cobra.Command{
Use: "serve",
Short: "Run the ward authentication server",
RunE: func(cmd *cobra.Command, args []string) error {
if debugFlag {
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
dbg.Print("debug logging enabled")
}
if k3sServer == "" {
return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
}
// ── LDAP ──────────────────────────────────────────────────────────────────
var ldapAuth *LDAPAuth
if ldapURI != "" || ldapOn {
var la *LDAPAuth
var err error
if ldapURI != "" {
la, err = NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword)
} else {
la, err = NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword)
}
if err != nil {
return fmt.Errorf("LDAP: %w", err)
}
ldapAuth = la
anon := ldapBindDN == ""
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, ldapDomain, anon)
}
if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
}
// ── Kerberos ──────────────────────────────────────────────────────────────
var krbAuth *KerberosAuth
if kerberosOn {
ka, err := NewKerberosAuth(keytabPath, spn)
if err != nil {
return fmt.Errorf("Kerberos: %w", err)
}
krbAuth = ka
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm)
}
// ── htpasswd ──────────────────────────────────────────────────────────────
var htpasswdAuth *HtpasswdAuth
if htpasswdPath != "" {
ha, err := NewHtpasswdAuth(htpasswdPath)
if err != nil {
return fmt.Errorf("htpasswd: %w", err)
}
htpasswdAuth = ha
log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
sighup := make(chan os.Signal, 1)
signal.Notify(sighup, syscall.SIGHUP)
go func() {
for range sighup {
if err := htpasswdAuth.Reload(); err != nil {
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
} else {
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
}
}
}()
}
// ── Handler ───────────────────────────────────────────────────────────────
h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{
ServerURL: k3sServer,
ServerCACert: serverCACert,
ClientCACert: clientCACert,
ClientCAKey: clientCAKey,
Duration: certDuration,
ClusterName: clusterName,
})
if err != nil {
return fmt.Errorf("handler: %w", err)
}
// ── TLS ───────────────────────────────────────────────────────────────────
if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil {
return fmt.Errorf("TLS: loading certificate: %w", err)
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
if err != nil {
return nil, fmt.Errorf("reloading TLS cert: %w", err)
}
return &cert, nil
},
}
ln, err := tls.Listen("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("listen %s: %w", addr, err)
}
mux := http.NewServeMux()
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
mux.HandleFunc("/credential", h.ServeCredential)
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
" GET /kubeconfig kubeconfig with embedded client certificate\n")
})
srv := &http.Server{
Handler: mux,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}
log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
return srv.Serve(ln)
},
}
cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
cmd.Flags().StringVar(&ldapBindPassword, "ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)")
cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
return cmd
}
func newCredentialCmd() *cobra.Command {
var (
server string
username string
noKerberos bool
noCache bool
debugFlag bool
)
cmd := &cobra.Command{
Use: "credential",
Short: "kubectl exec credential plugin — fetch an ExecCredential from ward",
Long: `Acts as a kubectl exec credential plugin. Fetches an ExecCredential JSON
from the ward server and prints it to stdout for kubectl to consume.
Authentication priority:
1. Kerberos SPNEGO using the active credential cache (from kinit)
2. Basic auth — prompts for password, or reads $WARD_PASSWORD
Credentials are cached in ~/.cache/ward/ and reused until 5 minutes
before expiry, so kubectl invocations are fast after the first call.
Debug output goes to stderr (kubectl surfaces this to the terminal):
WARD_DEBUG=1 kubectl get nodes`,
RunE: func(cmd *cobra.Command, args []string) error {
if server == "" {
return fmt.Errorf("--server is required")
}
// prepend https if no scheme is given, for user convenience
if !strings.Contains(server, "://") {
server = "https://" + server
}
// append port 8443 if no port is given
parts := strings.Split(server, ":")
if len(parts) == 2 {
server += ":8443"
}
logf := func(format string, a ...any) {
if debugFlag {
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
}
}
if !noCache {
if ec, ok := credReadCache(server, logf); ok {
return credPrint(ec)
}
}
ec, err := credFetch(server, username, noKerberos, logf)
if err != nil {
return err
}
if !noCache {
credWriteCache(server, ec, logf)
}
return credPrint(ec)
},
}
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
cmd.Flags().StringVar(&username, "username", "", "username for Basic auth fallback (default: $USER)")
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; always use Basic auth")
cmd.Flags().BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential")
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
return cmd
}
// detectFQDN returns the fully-qualified domain name of the local host,
// falling back to the short hostname if DNS resolution fails.
func detectFQDN() string {
hostname, err := os.Hostname()
if err != nil {
return "localhost"
}
if strings.Contains(hostname, ".") {
return hostname
}
addrs, err := net.LookupHost(hostname)
if err != nil || len(addrs) == 0 {
return hostname
}
names, err := net.LookupAddr(addrs[0])
if err != nil || len(names) == 0 {
return hostname
}
return strings.TrimSuffix(names[0], ".")
}
// domainPart strips the first label from a FQDN.
// "host.example.com" → "example.com"
func domainPart(fqdn string) string {
if idx := strings.IndexByte(fqdn, '.'); idx >= 0 {
return fqdn[idx+1:]
}
return fqdn
}
// firstLabel returns the first dot-separated label of a domain name.
// "example.com" → "example"
func firstLabel(domain string) string {
if idx := strings.IndexByte(domain, '.'); idx >= 0 {
return domain[:idx]
}
return domain
} }

View File

@@ -1,4 +1,4 @@
package main package auth
import ( import (
"bufio" "bufio"

View File

@@ -1,4 +1,4 @@
package main package auth
import ( import (
"fmt" "fmt"

View File

@@ -1,8 +1,9 @@
package main package auth
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log"
"net" "net"
"net/url" "net/url"
"strconv" "strconv"
@@ -19,6 +20,7 @@ type LDAPAuth struct {
useTLS bool // true = LDAPS (TLS from the start); false = STARTTLS on port 389 useTLS bool // true = LDAPS (TLS from the start); false = STARTTLS on port 389
bindDN string // empty = anonymous bind bindDN string // empty = anonymous bind
bindPassword string bindPassword string
log *log.Logger
} }
// NewLDAPAuth discovers the LDAP server for domain via the standard _ldap._tcp // NewLDAPAuth discovers the LDAP server for domain via the standard _ldap._tcp
@@ -26,7 +28,7 @@ type LDAPAuth struct {
// port 389 uses STARTTLS; anything else (typically 636) uses LDAPS. // port 389 uses STARTTLS; anything else (typically 636) uses LDAPS.
// _ldaps._tcp is not an IANA-registered SRV type and is not consulted. // _ldaps._tcp is not an IANA-registered SRV type and is not consulted.
// bindDN and bindPassword are used for the search bind; both empty = anonymous. // bindDN and bindPassword are used for the search bind; both empty = anonymous.
func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) { func NewLDAPAuth(domain, bindDN, bindPassword string, dbg *log.Logger) (*LDAPAuth, error) {
host, port, err := discoverLDAP(domain) host, port, err := discoverLDAP(domain)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -38,6 +40,7 @@ func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) {
useTLS: port != 389, useTLS: port != 389,
bindDN: bindDN, bindDN: bindDN,
bindPassword: bindPassword, bindPassword: bindPassword,
log: dbg,
}, nil }, nil
} }
@@ -46,7 +49,7 @@ func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) {
// Port defaults to 389 for ldap:// and 636 for ldaps:// if not specified. // Port defaults to 389 for ldap:// and 636 for ldaps:// if not specified.
// domain is still required for base-DN derivation and UPN construction. // domain is still required for base-DN derivation and UPN construction.
// bindDN and bindPassword are used for the search bind; both empty = anonymous. // bindDN and bindPassword are used for the search bind; both empty = anonymous.
func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string) (*LDAPAuth, error) { func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string, dbg *log.Logger) (*LDAPAuth, error) {
u, err := url.Parse(rawURI) u, err := url.Parse(rawURI)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid LDAP URI %q: %w", rawURI, err) return nil, fmt.Errorf("invalid LDAP URI %q: %w", rawURI, err)
@@ -78,9 +81,19 @@ func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string) (*LDAPAuth,
useTLS: useTLS, useTLS: useTLS,
bindDN: bindDN, bindDN: bindDN,
bindPassword: bindPassword, bindPassword: bindPassword,
log: dbg,
}, nil }, nil
} }
// Host returns the LDAP server hostname (used for logging).
func (a *LDAPAuth) Host() string { return a.host }
// Port returns the LDAP server port (used for logging).
func (a *LDAPAuth) Port() int { return a.port }
// UseTLS reports whether TLS is in use (used for logging).
func (a *LDAPAuth) UseTLS() bool { return a.useTLS }
func discoverLDAP(domain string) (host string, port int, err error) { func discoverLDAP(domain string) (host string, port int, err error) {
_, addrs, err := net.LookupSRV("ldap", "tcp", domain) _, addrs, err := net.LookupSRV("ldap", "tcp", domain)
if err != nil || len(addrs) == 0 { if err != nil || len(addrs) == 0 {
@@ -93,10 +106,10 @@ func (a *LDAPAuth) connect() (*ldap.Conn, error) {
addr := fmt.Sprintf("%s:%d", a.host, a.port) addr := fmt.Sprintf("%s:%d", a.host, a.port)
tlsConfig := &tls.Config{ServerName: a.host} tlsConfig := &tls.Config{ServerName: a.host}
if a.useTLS { if a.useTLS {
dbg.Printf("LDAP: dialing TLS %s", addr) a.log.Printf("LDAP: dialing TLS %s", addr)
return ldap.DialTLS("tcp", addr, tlsConfig) return ldap.DialTLS("tcp", addr, tlsConfig)
} }
dbg.Printf("LDAP: dialing %s + STARTTLS", addr) a.log.Printf("LDAP: dialing %s + STARTTLS", addr)
conn, err := ldap.Dial("tcp", addr) conn, err := ldap.Dial("tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -157,27 +170,27 @@ func (a *LDAPAuth) Authenticate(username, password string) (groups []string, err
return nil, fmt.Errorf("LDAP search bind failed: %w", err) return nil, fmt.Errorf("LDAP search bind failed: %w", err)
} }
dbg.Printf("LDAP: searching for user %q", username) a.log.Printf("LDAP: searching for user %q", username)
userDN, err := a.findUserDN(conn, username) userDN, err := a.findUserDN(conn, username)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid credentials") return nil, fmt.Errorf("invalid credentials")
} }
// Verify the password by binding as the user. // Verify the password by binding as the user.
dbg.Printf("LDAP: binding as %s", userDN) a.log.Printf("LDAP: binding as %s", userDN)
if err := conn.Bind(userDN, password); err != nil { if err := conn.Bind(userDN, password); err != nil {
return nil, fmt.Errorf("invalid credentials") return nil, fmt.Errorf("invalid credentials")
} }
dbg.Printf("LDAP: bind OK for %s", userDN) a.log.Printf("LDAP: bind OK for %s", userDN)
// Re-bind as service account for group lookup — the user may lack read access. // Re-bind as service account for group lookup — the user may lack read access.
if err := a.searchBind(conn); err != nil { if err := a.searchBind(conn); err != nil {
dbg.Printf("LDAP: re-bind for group lookup failed: %v — skipping groups", err) a.log.Printf("LDAP: re-bind for group lookup failed: %v — skipping groups", err)
return nil, nil // auth succeeded; group lookup is best-effort return nil, nil // auth succeeded; group lookup is best-effort
} }
groups = a.lookupGroups(conn, username, userDN) groups = a.lookupGroups(conn, username, userDN)
dbg.Printf("LDAP: groups for %s: %v", username, groups) a.log.Printf("LDAP: groups for %s: %v", username, groups)
return groups, nil return groups, nil
} }

View File

@@ -1,4 +1,4 @@
package main package cert
import ( import (
"bytes" "bytes"
@@ -19,12 +19,13 @@ import (
// CertConfig holds paths and settings for certificate and kubeconfig generation. // CertConfig holds paths and settings for certificate and kubeconfig generation.
type CertConfig struct { type CertConfig struct {
ServerURL string // written as the cluster.server in the returned kubeconfig ServerURL string // written as the cluster.server in the returned kubeconfig
ServerCACert string // path — embedded in the kubeconfig so kubectl can verify the API server ServerCACert string // path — embedded in the kubeconfig so kubectl can verify the API server
ClientCACert string // path — signs the per-user client certificates ClientCACert string // path — signs the per-user client certificates
ClientCAKey string // path ClientCAKey string // path
Duration time.Duration // validity period for generated client certs Duration time.Duration // validity period for generated client certs
ClusterName string // name used for the cluster/context in generated kubeconfigs LoginDuration time.Duration // validity period for certs issued by 'ward login'
ClusterName string // name used for the cluster/context in generated kubeconfigs
} }
// KubeconfigGenerator loads the k3s CAs once and issues per-user kubeconfigs on demand. // KubeconfigGenerator loads the k3s CAs once and issues per-user kubeconfigs on demand.
@@ -55,6 +56,9 @@ func NewKubeconfigGenerator(cfg *CertConfig) (*KubeconfigGenerator, error) {
}, nil }, nil
} }
// Cfg returns the CertConfig (used by handler to read Duration/LoginDuration).
func (g *KubeconfigGenerator) Cfg() *CertConfig { return g.cfg }
// Credential holds the raw PEM blobs and expiry for a generated client certificate. // Credential holds the raw PEM blobs and expiry for a generated client certificate.
// Used both for kubeconfig generation and the /credential exec-plugin endpoint. // Used both for kubeconfig generation and the /credential exec-plugin endpoint.
type Credential struct { type Credential struct {
@@ -65,8 +69,8 @@ type Credential struct {
// GenerateCredential signs a fresh client certificate and returns the raw PEM data. // GenerateCredential signs a fresh client certificate and returns the raw PEM data.
// Use this when you need the cert material directly (e.g. the exec credential plugin). // Use this when you need the cert material directly (e.g. the exec credential plugin).
func (g *KubeconfigGenerator) GenerateCredential(username string, groups []string) (*Credential, error) { func (g *KubeconfigGenerator) GenerateCredential(username string, groups []string, duration time.Duration) (*Credential, error) {
certPEM, keyPEM, err := g.signClientCert(username, groups) certPEM, keyPEM, err := g.signClientCert(username, groups, duration)
if err != nil { if err != nil {
return nil, fmt.Errorf("signing cert for %s: %w", username, err) return nil, fmt.Errorf("signing cert for %s: %w", username, err)
} }
@@ -82,7 +86,7 @@ func (g *KubeconfigGenerator) GenerateCredential(username string, groups []strin
// groups are embedded as the certificate's Organisation field, which Kubernetes reads // groups are embedded as the certificate's Organisation field, which Kubernetes reads
// as RBAC group memberships. // as RBAC group memberships.
func (g *KubeconfigGenerator) Generate(username string, groups []string) ([]byte, error) { func (g *KubeconfigGenerator) Generate(username string, groups []string) ([]byte, error) {
cred, err := g.GenerateCredential(username, groups) cred, err := g.GenerateCredential(username, groups, g.cfg.Duration)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -97,7 +101,7 @@ func (g *KubeconfigGenerator) Generate(username string, groups []string) ([]byte
} }
// signClientCert issues an ECDSA P-256 client certificate signed by the k3s client CA. // signClientCert issues an ECDSA P-256 client certificate signed by the k3s client CA.
func (g *KubeconfigGenerator) signClientCert(username string, groups []string) (certPEM, keyPEM []byte, err error) { func (g *KubeconfigGenerator) signClientCert(username string, groups []string, duration time.Duration) (certPEM, keyPEM []byte, err error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("generating key: %w", err) return nil, nil, fmt.Errorf("generating key: %w", err)
@@ -116,7 +120,7 @@ func (g *KubeconfigGenerator) signClientCert(username string, groups []string) (
Organization: groups, // Kubernetes maps these to RBAC groups Organization: groups, // Kubernetes maps these to RBAC groups
}, },
NotBefore: now.Add(-5 * time.Minute), // tolerate minor clock skew NotBefore: now.Add(-5 * time.Minute), // tolerate minor clock skew
NotAfter: now.Add(g.cfg.Duration), NotAfter: now.Add(duration),
KeyUsage: x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true, BasicConstraintsValid: true,
@@ -186,11 +190,11 @@ func loadCA(certFile, keyFile string) (*x509.Certificate, crypto.PrivateKey, err
func (g *KubeconfigGenerator) GenerateBootstrap(wardURL, username string) ([]byte, error) { func (g *KubeconfigGenerator) GenerateBootstrap(wardURL, username string) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
err := bootstrapTmpl.Execute(&buf, map[string]string{ err := bootstrapTmpl.Execute(&buf, map[string]string{
"Server": g.cfg.ServerURL, "Server": g.cfg.ServerURL,
"ServerCA": g.serverCA, "ServerCA": g.serverCA,
"WardURL": wardURL, "WardURL": wardURL,
"Username": username, "Username": username,
"Cluster": g.cfg.ClusterName, "Cluster": g.cfg.ClusterName,
}) })
return buf.Bytes(), err return buf.Bytes(), err
} }

View File

@@ -1,4 +1,4 @@
package main package handler
import ( import (
"encoding/json" "encoding/json"
@@ -8,6 +8,9 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"time" "time"
"git.shee.sh/james/ward/pkg/auth"
"git.shee.sh/james/ward/pkg/cert"
) )
// ExecCredential is the JSON structure kubectl expects from an exec credential plugin. // ExecCredential is the JSON structure kubectl expects from an exec credential plugin.
@@ -28,23 +31,24 @@ type ExecCredentialStatus struct {
// Handler wires together authentication providers and kubeconfig generation. // Handler wires together authentication providers and kubeconfig generation.
// At least one provider must be non-nil. // At least one provider must be non-nil.
type Handler struct { type Handler struct {
ldap *LDAPAuth ldap *auth.LDAPAuth
krb *KerberosAuth krb *auth.KerberosAuth
htpasswd *HtpasswdAuth htpasswd *auth.HtpasswdAuth
gen *KubeconfigGenerator gen *cert.KubeconfigGenerator
log *log.Logger
} }
// NewHandler validates that at least one auth provider is configured, then // NewHandler validates that at least one auth provider is configured, then
// loads the k3s CA files and returns a ready Handler. // loads the k3s CA files and returns a ready Handler.
func NewHandler(ldap *LDAPAuth, krb *KerberosAuth, htpasswd *HtpasswdAuth, cfg *CertConfig) (*Handler, error) { func NewHandler(ldapAuth *auth.LDAPAuth, krbAuth *auth.KerberosAuth, htpasswdAuth *auth.HtpasswdAuth, cfg *cert.CertConfig, dbg *log.Logger) (*Handler, error) {
if ldap == nil && krb == nil && htpasswd == nil { if ldapAuth == nil && krbAuth == nil && htpasswdAuth == nil {
return nil, fmt.Errorf("no authentication providers configured: enable at least one of LDAP, --kerberos, or --htpasswd") return nil, fmt.Errorf("no authentication providers configured: enable at least one of LDAP, --kerberos, or --htpasswd")
} }
gen, err := NewKubeconfigGenerator(cfg) gen, err := cert.NewKubeconfigGenerator(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Handler{ldap: ldap, krb: krb, htpasswd: htpasswd, gen: gen}, nil return &Handler{ldap: ldapAuth, krb: krbAuth, htpasswd: htpasswdAuth, gen: gen, log: dbg}, nil
} }
// ServeHTTP handles GET /kubeconfig — returns a kubeconfig YAML on success. // ServeHTTP handles GET /kubeconfig — returns a kubeconfig YAML on success.
@@ -83,10 +87,13 @@ func (h *Handler) ServeCredential(w http.ResponseWriter, r *http.Request) {
if !ok { if !ok {
return return
} }
dbg.Printf("issuing exec credential for %q groups=%v", username, groups)
log.Printf("issuing exec credential for %q groups=%v", username, groups) log.Printf("issuing exec credential for %q groups=%v", username, groups)
cred, err := h.gen.GenerateCredential(username, groups) duration := h.gen.Cfg().Duration
if r.URL.Query().Get("login") == "true" {
duration = h.gen.Cfg().LoginDuration
}
cred, err := h.gen.GenerateCredential(username, groups, duration)
if err != nil { if err != nil {
log.Printf("credential generation failed for %q: %v", username, err) log.Printf("credential generation failed for %q: %v", username, err)
http.Error(w, "internal server error", http.StatusInternalServerError) http.Error(w, "internal server error", http.StatusInternalServerError)
@@ -160,18 +167,18 @@ func (h *Handler) authenticate(w http.ResponseWriter, r *http.Request) (string,
switch { switch {
case h.krb != nil && strings.HasPrefix(authHeader, "Negotiate "): case h.krb != nil && strings.HasPrefix(authHeader, "Negotiate "):
dbg.Printf("auth: trying Kerberos SPNEGO") h.log.Printf("auth: trying Kerberos SPNEGO")
username, err := h.krb.Authenticate(w, r) username, err := h.krb.Authenticate(w, r)
if err != nil { if err != nil {
log.Printf("Kerberos auth failed: %v", err) log.Printf("Kerberos auth failed: %v", err)
h.sendChallenge(w, true, h.ldap != nil || h.htpasswd != nil) h.sendChallenge(w, true, h.ldap != nil || h.htpasswd != nil)
return "", nil, false return "", nil, false
} }
dbg.Printf("auth: Kerberos OK, user=%q", username) h.log.Printf("auth: Kerberos OK, user=%q", username)
var groups []string var groups []string
if h.ldap != nil { if h.ldap != nil {
groups = h.ldap.LookupGroups(username) groups = h.ldap.LookupGroups(username)
dbg.Printf("auth: LDAP group lookup for %q → %v", username, groups) h.log.Printf("auth: LDAP group lookup for %q → %v", username, groups)
} }
return username, groups, true return username, groups, true
@@ -181,18 +188,18 @@ func (h *Handler) authenticate(w http.ResponseWriter, r *http.Request) (string,
h.sendChallenge(w, h.krb != nil, true) h.sendChallenge(w, h.krb != nil, true)
return "", nil, false return "", nil, false
} }
dbg.Printf("auth: trying Basic for user=%q", user) h.log.Printf("auth: trying Basic for user=%q", user)
groups, err := h.authenticateBasic(user, password) groups, err := h.authenticateBasic(user, password)
if err != nil { if err != nil {
log.Printf("Basic auth failed for %q: %v", user, err) log.Printf("Basic auth failed for %q: %v", user, err)
h.sendChallenge(w, h.krb != nil, true) h.sendChallenge(w, h.krb != nil, true)
return "", nil, false return "", nil, false
} }
dbg.Printf("auth: Basic OK, user=%q groups=%v", user, groups) h.log.Printf("auth: Basic OK, user=%q groups=%v", user, groups)
return user, groups, true return user, groups, true
default: default:
dbg.Printf("auth: no Authorization header, sending challenges") h.log.Printf("auth: no Authorization header, sending challenges")
h.sendChallenge(w, h.krb != nil, h.ldap != nil || h.htpasswd != nil) h.sendChallenge(w, h.krb != nil, h.ldap != nil || h.htpasswd != nil)
return "", nil, false return "", nil, false
} }
@@ -206,7 +213,7 @@ func (h *Handler) authenticateBasic(username, password string) ([]string, error)
if err == nil { if err == nil {
return groups, nil return groups, nil
} }
dbg.Printf("LDAP auth failed for %q: %v", username, err) h.log.Printf("LDAP auth failed for %q: %v", username, err)
if h.htpasswd == nil { if h.htpasswd == nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,165 @@
package kubeconfig
import (
"fmt"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
// KubeConfig is the top-level kubeconfig structure.
type KubeConfig struct {
APIVersion string `yaml:"apiVersion"`
Kind string `yaml:"kind"`
Preferences map[string]any `yaml:"preferences,omitempty"`
Clusters []NamedCluster `yaml:"clusters"`
Users []NamedUser `yaml:"users"`
Contexts []NamedContext `yaml:"contexts"`
CurrentContext string `yaml:"current-context,omitempty"`
}
// NamedCluster is a named cluster entry.
type NamedCluster struct {
Name string `yaml:"name"`
Cluster ClusterData `yaml:"cluster"`
}
// ClusterData holds the cluster connection details.
type ClusterData struct {
Server string `yaml:"server"`
CertificateAuthorityData string `yaml:"certificate-authority-data,omitempty"`
}
// NamedUser is a named user entry.
type NamedUser struct {
Name string `yaml:"name"`
User UserData `yaml:"user"`
}
// UserData holds user credentials.
type UserData struct {
Exec *ExecData `yaml:"exec,omitempty"`
}
// ExecData holds exec plugin configuration.
type ExecData struct {
APIVersion string `yaml:"apiVersion"`
Command string `yaml:"command"`
Args []string `yaml:"args,omitempty"`
InteractiveMode string `yaml:"interactiveMode,omitempty"`
}
// NamedContext is a named context entry.
type NamedContext struct {
Name string `yaml:"name"`
Context ContextData `yaml:"context"`
}
// ContextData holds context details.
type ContextData struct {
Cluster string `yaml:"cluster"`
User string `yaml:"user"`
}
// FilePath returns the path to the active kubeconfig file.
// If KUBECONFIG is set to a colon-separated list, the first entry is returned.
func FilePath() string {
if k := os.Getenv("KUBECONFIG"); k != "" {
if idx := strings.IndexByte(k, os.PathListSeparator); idx >= 0 {
return k[:idx]
}
return k
}
home, _ := os.UserHomeDir()
return filepath.Join(home, ".kube", "config")
}
// RenameContext renames the cluster and context (but not the user) in cfg.
// The bootstrap template uses the cluster name as both the cluster and context
// name; the user name is the actual username and is left unchanged.
func RenameContext(cfg *KubeConfig, newName string) {
oldName := cfg.CurrentContext
if oldName == newName {
return
}
for i := range cfg.Clusters {
if cfg.Clusters[i].Name == oldName {
cfg.Clusters[i].Name = newName
}
}
for i := range cfg.Contexts {
if cfg.Contexts[i].Name == oldName {
cfg.Contexts[i].Name = newName
cfg.Contexts[i].Context.Cluster = newName
}
}
cfg.CurrentContext = newName
}
// Merge merges incoming into the kubeconfig file at FilePath().
// If setContext is true, the current-context is updated to incoming.CurrentContext.
func Merge(incoming *KubeConfig, setContext bool) error {
path := FilePath()
existing := &KubeConfig{APIVersion: "v1", Kind: "Config"}
if data, err := os.ReadFile(path); err == nil {
if err := yaml.Unmarshal(data, existing); err != nil {
return fmt.Errorf("parsing existing kubeconfig: %w", err)
}
}
for _, c := range incoming.Clusters {
existing.Clusters = upsertCluster(existing.Clusters, c)
}
for _, u := range incoming.Users {
existing.Users = upsertUser(existing.Users, u)
}
for _, ctx := range incoming.Contexts {
existing.Contexts = upsertContext(existing.Contexts, ctx)
}
if setContext {
existing.CurrentContext = incoming.CurrentContext
}
data, err := yaml.Marshal(existing)
if err != nil {
return fmt.Errorf("marshaling kubeconfig: %w", err)
}
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
return fmt.Errorf("creating kubeconfig directory: %w", err)
}
return os.WriteFile(path, data, 0600)
}
func upsertCluster(list []NamedCluster, item NamedCluster) []NamedCluster {
for i, c := range list {
if c.Name == item.Name {
list[i] = item
return list
}
}
return append(list, item)
}
func upsertUser(list []NamedUser, item NamedUser) []NamedUser {
for i, u := range list {
if u.Name == item.Name {
list[i] = item
return list
}
}
return append(list, item)
}
func upsertContext(list []NamedContext, item NamedContext) []NamedContext {
for i, c := range list {
if c.Name == item.Name {
list[i] = item
return list
}
}
return append(list, item)
}