Files
ward/cmd_credential.go
2026-03-02 15:19:32 +01:00

317 lines
10 KiB
Go

package main
import (
"crypto/sha256"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"golang.org/x/term"
krb5client "github.com/jcmturner/gokrb5/v8/client"
"github.com/jcmturner/gokrb5/v8/config"
"github.com/jcmturner/gokrb5/v8/credentials"
"github.com/jcmturner/gokrb5/v8/spnego"
)
// runCredential implements the "ward credential" subcommand, which acts as a
// kubectl exec credential plugin. It fetches an ExecCredential JSON from the
// ward server and prints it to stdout for kubectl to consume.
//
// Authentication priority:
// 1. Kerberos SPNEGO using the active credential cache (from kinit)
// 2. Basic auth — prompts for password, or reads $WARD_PASSWORD
//
// Credentials are cached in ~/.cache/ward/ and reused until 5 minutes
// before expiry, so kubectl invocations are fast after the first call.
//
// Debug output goes to stderr (kubectl surfaces this to the terminal):
//
// WARD_DEBUG=1 kubectl get nodes
func runCredential(args []string) int {
fs := flag.NewFlagSet("credential", flag.ContinueOnError)
fs.SetOutput(os.Stderr)
server := fs.String("server", "", "ward server URL (required)")
username := fs.String("username", "", "username for Basic auth fallback (default: $USER)")
noKerberos := fs.Bool("no-kerberos", false, "skip Kerberos; always use Basic auth")
noCache := fs.Bool("no-cache", false, "bypass local cache; always fetch a fresh credential")
debug := fs.Bool("debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
if err := fs.Parse(args); err != nil {
return 1
}
if *server == "" {
fmt.Fprintln(os.Stderr, "ward credential: --server is required")
fs.PrintDefaults()
return 1
}
logf := func(format string, a ...interface{}) {
if *debug {
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
}
}
// ── Cache ─────────────────────────────────────────────────────────────────
if !*noCache {
if ec, ok := credReadCache(*server, logf); ok {
return credPrint(ec)
}
}
// ── Fetch from ward ───────────────────────────────────────────────────
ec, err := credFetch(*server, *username, *noKerberos, logf)
if err != nil {
fmt.Fprintf(os.Stderr, "ward credential: %v\n", err)
return 1
}
if !*noCache {
credWriteCache(*server, ec, logf)
}
return credPrint(ec)
}
// credFetch obtains a fresh ExecCredential from the "/credential" endpoint of
// the ward server at server (a trailing slash on server is tolerated).
//
// Unless noKerberos is set, it first tries Kerberos SPNEGO. If that succeeds,
// the response is parsed and returned and Basic auth is not attempted. If it
// fails for any reason, a warning is printed to stderr and it falls back to
// HTTP Basic auth. For Basic auth the username defaults to $USER. The password
// comes from $WARD_PASSWORD or, when that is empty, an interactive prompt.
//
// It returns an error in these cases:
//   - no username can be determined;
//   - the password cannot be read;
//   - the request fails;
//   - the response is not a valid ExecCredential.
func credFetch(server, username string, noKerberos bool, logf func(string, ...interface{})) (*ExecCredential, error) {
	endpoint := strings.TrimRight(server, "/") + "/credential"

	// ── Kerberos SPNEGO ───────────────────────────────────────────────────────
	if !noKerberos {
		body, err := credFetchKerberos(endpoint, logf)
		if err == nil {
			return credParse(body)
		}
		// Always shown, not only under --debug, so users learn why Basic auth
		// is being used. A separate logf line would just repeat this text
		// when debug output is enabled.
		fmt.Fprintf(os.Stderr, "ward: Kerberos failed (%v); falling back to Basic auth\nhint: run 'kinit' to avoid the password prompt\n", err)
	}

	// ── Basic auth ────────────────────────────────────────────────────────────
	if username == "" {
		username = os.Getenv("USER")
	}
	if username == "" {
		// Otherwise an empty username would be sent, after a confusing
		// "Password for : " prompt.
		return nil, fmt.Errorf("no username for Basic auth: pass --username or set $USER")
	}

	password := os.Getenv("WARD_PASSWORD")
	if password == "" {
		var err error
		password, err = credPromptPassword(username)
		if err != nil {
			return nil, err
		}
	} else {
		logf("Basic: using password from $WARD_PASSWORD")
	}

	body, err := credFetchBasic(endpoint, username, password, logf)
	if err != nil {
		return nil, err
	}
	return credParse(body)
}
func credFetchKerberos(url string, logf func(string, ...interface{})) ([]byte, error) {
krb5cfgPath := krb5ConfigPath()
logf("Kerberos: loading config from %s", krb5cfgPath)
krb5cfg, err := config.Load(krb5cfgPath)
if err != nil {
return nil, fmt.Errorf("loading krb5 config: %w", err)
}
ccPath := ccachePath()
logf("Kerberos: loading credential cache from %s", ccPath)
ccache, err := credentials.LoadCCache(ccPath)
if err != nil {
return nil, fmt.Errorf("no credential cache (%w) — run 'kinit'", err)
}
cl, err := krb5client.NewFromCCache(ccache, krb5cfg, krb5client.DisablePAFXFAST(true))
if err != nil {
return nil, fmt.Errorf("creating Kerberos client: %w", err)
}
defer cl.Destroy()
logf("Kerberos: principal=%s@%s", cl.Credentials.UserName(), cl.Credentials.Domain())
spnegoClient := spnego.NewClient(cl, nil, "")
logf("HTTP: GET %s (Negotiate)", url)
resp, err := spnegoClient.Get(url)
if err != nil {
return nil, fmt.Errorf("HTTP request: %w", err)
}
defer resp.Body.Close()
logf("HTTP: %s", resp.Status)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return nil, fmt.Errorf("server rejected Kerberos token (check SPN and keytab)")
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned %s", resp.Status)
}
logf("HTTP: received %d bytes", len(body))
return body, nil
}
func credFetchBasic(url, username, password string, logf func(string, ...interface{})) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("building request: %w", err)
}
req.SetBasicAuth(username, password)
logf("HTTP: GET %s (Basic, user=%s)", url, username)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("HTTP request: %w", err)
}
defer resp.Body.Close()
logf("HTTP: %s", resp.Status)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode == http.StatusUnauthorized {
return nil, fmt.Errorf("authentication failed (wrong username or password)")
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
}
logf("HTTP: received %d bytes", len(body))
return body, nil
}
func credParse(body []byte) (*ExecCredential, error) {
var ec ExecCredential
if err := json.Unmarshal(body, &ec); err != nil {
return nil, fmt.Errorf("parsing server response: %w", err)
}
if ec.Status == nil {
return nil, fmt.Errorf("server returned ExecCredential with no status field")
}
return &ec, nil
}
// credPrint writes ec to stdout as JSON, followed by a newline, for kubectl to
// consume. It returns the process exit code: 0 on success, or 1 if encoding or
// writing fails, in which case the error is reported on stderr.
func credPrint(ec *ExecCredential) int {
	err := json.NewEncoder(os.Stdout).Encode(ec)
	if err == nil {
		return 0
	}
	fmt.Fprintf(os.Stderr, "ward credential: writing output: %v\n", err)
	return 1
}
// ── Local credential cache ─────────────────────────────────────────────────────
// Caches ExecCredential JSON in ~/.cache/ward/<sha256(serverURL)>.json.
// The cache is consulted before contacting ward; a hit avoids a round-trip
// and, for Kerberos, avoids acquiring a service ticket on every kubectl call.
// credCacheDir returns the directory that holds cached ExecCredentials.
// This is $XDG_CACHE_HOME/ward when that variable is an absolute path, and
// <home>/.cache/ward otherwise.
//
// The XDG Base Directory spec says relative values must be treated as invalid
// and ignored. Honoring one would place cached credentials relative to the
// current working directory, i.e. wherever kubectl happened to be run.
func credCacheDir() string {
	if d := os.Getenv("XDG_CACHE_HOME"); d != "" && filepath.IsAbs(d) {
		return filepath.Join(d, "ward")
	}
	home, err := os.UserHomeDir()
	if err != nil {
		// NOTE(review): the home directory could not be determined (e.g. $HOME
		// is unset). home is then "", and the result degrades to the relative
		// path ".cache/ward". This signature cannot report the failure, so
		// callers may want to skip caching in this situation.
		home = ""
	}
	return filepath.Join(home, ".cache", "ward")
}
// credCacheFile returns the cache path for serverURL. The file lives in
// credCacheDir() and is named by the hex SHA-256 digest of the URL plus a
// ".json" suffix. Hashing keeps characters such as "/" and ":" out of the file
// name. The URL is hashed verbatim, so "https://x" and "https://x/" map to
// different files.
func credCacheFile(serverURL string) string {
	digest := sha256.Sum256([]byte(serverURL))
	name := fmt.Sprintf("%x", digest[:]) + ".json"
	return filepath.Join(credCacheDir(), name)
}
// credReadCache returns the cached ExecCredential for serverURL and true when
// the cache file exists and decodes cleanly. It also requires an RFC 3339
// Status.ExpirationTimestamp that is at least refreshMargin in the future.
//
// Any other outcome is logged via logf and reported as a miss (nil, false) so
// the caller fetches a fresh credential:
//   - the file is missing;
//   - the JSON is corrupt;
//   - the status or expiry is missing;
//   - the expiry is unparseable;
//   - the credential is close to expiry.
func credReadCache(serverURL string, logf func(string, ...interface{})) (*ExecCredential, bool) {
	// Credentials this close to expiry are treated as stale, so a new one is
	// fetched before the cached one lapses.
	const refreshMargin = 5 * time.Minute

	path := credCacheFile(serverURL)
	logf("cache: checking %s", path)
	data, err := os.ReadFile(path)
	if err != nil {
		logf("cache: miss (%v)", err)
		return nil, false
	}

	var ec ExecCredential
	if err := json.Unmarshal(data, &ec); err != nil {
		logf("cache: corrupt, ignoring (%v)", err)
		return nil, false
	}
	if ec.Status == nil || ec.Status.ExpirationTimestamp == "" {
		logf("cache: no expiry stored, refreshing")
		return nil, false
	}

	expiry, err := time.Parse(time.RFC3339, ec.Status.ExpirationTimestamp)
	if err != nil {
		logf("cache: unparseable expiry %q, refreshing", ec.Status.ExpirationTimestamp)
		return nil, false
	}
	remaining := time.Until(expiry)
	if remaining < refreshMargin {
		logf("cache: expiring soon (%v remaining), refreshing", remaining.Truncate(time.Second))
		return nil, false
	}

	logf("cache: hit — credential expires %s (%v remaining)",
		expiry.Format(time.RFC3339), remaining.Truncate(time.Second))
	return &ec, true
}
// credWriteCache stores ec as JSON in credCacheFile(serverURL). The cache
// directory is created with mode 0700 if needed.
//
// The write is atomic. The data goes to a temporary file in the same directory,
// which os.CreateTemp creates with mode 0600, and that file is then renamed
// over the cache file. As a result:
//   - a concurrent credReadCache never sees a partially written file;
//   - the final file is always 0600, even if an older copy had wider
//     permissions (os.WriteFile keeps an existing file's mode).
//
// Caching is best-effort. Failures are logged via logf and otherwise ignored,
// because the caller already holds a valid credential.
func credWriteCache(serverURL string, ec *ExecCredential, logf func(string, ...interface{})) {
	dir := credCacheDir()
	if err := os.MkdirAll(dir, 0700); err != nil {
		logf("cache: failed to create dir: %v", err)
		return
	}
	data, err := json.Marshal(ec)
	if err != nil {
		logf("cache: marshal failed: %v", err)
		return
	}
	// credCacheFile places the file inside credCacheDir(), i.e. dir, so the
	// temporary file and the target share a directory and rename stays local.
	path := credCacheFile(serverURL)

	tmp, err := os.CreateTemp(dir, ".ward-cred-*.tmp")
	if err != nil {
		logf("cache: write failed: %v", err)
		return
	}
	tmpPath := tmp.Name()
	committed := false
	defer func() {
		if !committed {
			// Best effort: a leftover temp file is never read by credReadCache.
			_ = os.Remove(tmpPath)
		}
	}()

	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close()
		logf("cache: write failed: %v", err)
		return
	}
	if err := tmp.Close(); err != nil {
		logf("cache: write failed: %v", err)
		return
	}
	if err := os.Rename(tmpPath, path); err != nil {
		logf("cache: write failed: %v", err)
		return
	}
	committed = true
	logf("cache: saved to %s", path)
}
// ── Kerberos helpers ───────────────────────────────────────────────────────────
func krb5ConfigPath() string {
if v := os.Getenv("KRB5_CONFIG"); v != "" {
return v
}
return "/etc/krb5.conf"
}
// ccachePath returns the location of the active Kerberos credential cache.
//
// When $KRB5CCNAME is set, its value is used with any leading "FILE:" prefix
// removed. Other ccache types (API:, KEYRING:, DIR:) are passed through
// unchanged; gokrb5 cannot load them, so LoadCCache will fail on them. When
// $KRB5CCNAME is unset or empty, the result is /tmp/krb5cc_<uid>.
func ccachePath() string {
	name := os.Getenv("KRB5CCNAME")
	if name == "" {
		return fmt.Sprintf("/tmp/krb5cc_%d", os.Getuid())
	}
	return strings.TrimPrefix(name, "FILE:")
}
// ── Password prompt ────────────────────────────────────────────────────────────
// credPromptPassword interactively asks for username's password. The prompt
// is written to stderr, keeping stdout free for the ExecCredential JSON. The
// password is read from stdin with terminal echo disabled.
//
// It returns an error without prompting if stdin is not a terminal, or if
// reading the password fails. A newline is written to stderr after the read in
// both outcomes, because the user's Enter keypress is not echoed.
func credPromptPassword(username string) (string, error) {
	fd := int(os.Stdin.Fd())
	if !term.IsTerminal(fd) {
		return "", fmt.Errorf(
			"stdin is not a terminal and $WARD_PASSWORD is not set\n" +
				"hint: run 'kinit' for Kerberos auth, or set $WARD_PASSWORD for non-interactive use")
	}

	fmt.Fprintf(os.Stderr, "Password for %s: ", username)
	secret, readErr := term.ReadPassword(fd)
	fmt.Fprintln(os.Stderr)
	if readErr != nil {
		return "", fmt.Errorf("reading password: %w", readErr)
	}
	return string(secret), nil
}