271 lines
8.5 KiB
Go
271 lines
8.5 KiB
Go
package main
|
|
|
|
import (
|
|
"cmp"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"golang.org/x/term"
|
|
|
|
krb5client "github.com/jcmturner/gokrb5/v8/client"
|
|
"github.com/jcmturner/gokrb5/v8/config"
|
|
"github.com/jcmturner/gokrb5/v8/credentials"
|
|
"github.com/jcmturner/gokrb5/v8/spnego"
|
|
)
|
|
|
|
func credFetch(server, username string, noKerberos bool, logf func(string, ...interface{})) (*ExecCredential, error) {
|
|
url := strings.TrimRight(server, "/") + "/credential"
|
|
|
|
// ── Kerberos SPNEGO ───────────────────────────────────────────────────────
|
|
if !noKerberos {
|
|
body, err := credFetchKerberos(url, logf)
|
|
if err != nil {
|
|
logf("Kerberos: %v — falling back to Basic auth", err)
|
|
fmt.Fprintf(os.Stderr, "ward: Kerberos failed (%v); falling back to Basic auth\nhint: run 'kinit' to avoid the password prompt\n", err)
|
|
} else {
|
|
return credParse(body)
|
|
}
|
|
}
|
|
|
|
// ── Basic auth ────────────────────────────────────────────────────────────
|
|
if username == "" {
|
|
username = os.Getenv("USER")
|
|
}
|
|
password := os.Getenv("WARD_PASSWORD")
|
|
if password == "" {
|
|
var err error
|
|
password, err = credPromptPassword(username)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
logf("Basic: using password from $WARD_PASSWORD")
|
|
}
|
|
|
|
body, err := credFetchBasic(url, username, password, logf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return credParse(body)
|
|
}
|
|
|
|
func credFetchKerberos(url string, logf func(string, ...interface{})) ([]byte, error) {
|
|
krb5cfgPath := krb5ConfigPath()
|
|
logf("Kerberos: loading config from %s", krb5cfgPath)
|
|
krb5cfg, err := config.Load(krb5cfgPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("loading krb5 config: %w", err)
|
|
}
|
|
|
|
ccPath := ccachePath()
|
|
logf("Kerberos: loading credential cache from %s", ccPath)
|
|
ccache, err := credentials.LoadCCache(ccPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no credential cache (%w) — run 'kinit'", err)
|
|
}
|
|
|
|
cl, err := krb5client.NewFromCCache(ccache, krb5cfg, krb5client.DisablePAFXFAST(true))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating Kerberos client: %w", err)
|
|
}
|
|
defer cl.Destroy()
|
|
|
|
logf("Kerberos: principal=%s@%s", cl.Credentials.UserName(), cl.Credentials.Domain())
|
|
|
|
spnegoClient := spnego.NewClient(cl, nil, "")
|
|
logf("HTTP: GET %s (Negotiate)", url)
|
|
resp, err := spnegoClient.Get(url)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("HTTP request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
logf("HTTP: %s", resp.Status)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading response: %w", err)
|
|
}
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
return nil, fmt.Errorf("server rejected Kerberos token (check SPN and keytab)")
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("server returned %s", resp.Status)
|
|
}
|
|
logf("HTTP: received %d bytes", len(body))
|
|
return body, nil
|
|
}
|
|
|
|
// credFetchBasic performs a GET of url using HTTP Basic authentication as
// username/password and returns the raw response body on HTTP 200.
//
// It returns an error on transport failure (including the request timeout),
// on HTTP 401 ("authentication failed"), and on any other non-200 status, in
// which case the trimmed response body is included in the message.
//
// NOTE(review): Basic auth only base64-encodes the password and nothing here
// requires an https:// url — confirm the ward server URL is always TLS.
func credFetchBasic(url, username, password string, logf func(string, ...interface{})) ([]byte, error) {
	// A dedicated client rather than http.DefaultClient, which has no timeout:
	// without one, a stalled server would hang the caller indefinitely.
	client := &http.Client{Timeout: 30 * time.Second}

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("building request: %w", err)
	}
	req.SetBasicAuth(username, password)
	logf("HTTP: GET %s (Basic, user=%s)", url, username)

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("HTTP request: %w", err)
	}
	defer resp.Body.Close()
	logf("HTTP: %s", resp.Status)

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}
	if resp.StatusCode == http.StatusUnauthorized {
		return nil, fmt.Errorf("authentication failed (wrong username or password)")
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("server returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
	}
	logf("HTTP: received %d bytes", len(body))
	return body, nil
}
|
|
|
|
func credParse(body []byte) (*ExecCredential, error) {
|
|
var ec ExecCredential
|
|
if err := json.Unmarshal(body, &ec); err != nil {
|
|
return nil, fmt.Errorf("parsing server response: %w", err)
|
|
}
|
|
if ec.Status == nil {
|
|
return nil, fmt.Errorf("server returned ExecCredential with no status field")
|
|
}
|
|
return &ec, nil
|
|
}
|
|
|
|
func credPrint(ec *ExecCredential) error {
|
|
return json.NewEncoder(os.Stdout).Encode(ec)
|
|
}
|
|
|
|
// ── Local credential cache ─────────────────────────────────────────────────────
|
|
// Caches ExecCredential JSON as <sha256(serverURL)>.json in the ward cache
// directory (normally $XDG_CACHE_HOME/ward or ~/.cache/ward).
|
|
// The cache is consulted before contacting ward; a hit avoids a round-trip
|
|
// and, for Kerberos, avoids acquiring a service ticket on every kubectl call.
|
|
|
|
// credCacheDir returns the directory holding ward's cached credentials:
// $XDG_CACHE_HOME/ward when that variable holds an absolute path, otherwise
// ~/.cache/ward.
//
// A relative $XDG_CACHE_HOME is ignored, as the XDG Base Directory
// specification requires. Honoring it would resolve against the current
// working directory and leave credential files wherever ward happens to run.
//
// NOTE(review): if the home directory cannot be determined, home is "" and the
// result is the relative path ".cache/ward", i.e. credentials end up under the
// current working directory. Consider disabling the cache in that case.
func credCacheDir() string {
	if d := os.Getenv("XDG_CACHE_HOME"); filepath.IsAbs(d) {
		return filepath.Join(d, "ward")
	}
	home, _ := os.UserHomeDir() // error deliberately ignored; see NOTE above
	return filepath.Join(home, ".cache", "ward")
}
|
|
|
|
func credCacheFile(serverURL string) string {
|
|
h := sha256.Sum256([]byte(serverURL))
|
|
return filepath.Join(credCacheDir(), fmt.Sprintf("%x.json", h))
|
|
}
|
|
|
|
func credReadCache(serverURL string, logf func(string, ...interface{})) (*ExecCredential, bool) {
|
|
path := credCacheFile(serverURL)
|
|
logf("cache: checking %s", path)
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
logf("cache: miss (%v)", err)
|
|
return nil, false
|
|
}
|
|
var ec ExecCredential
|
|
if err := json.Unmarshal(data, &ec); err != nil {
|
|
logf("cache: corrupt, ignoring (%v)", err)
|
|
return nil, false
|
|
}
|
|
if ec.Status == nil || ec.Status.ExpirationTimestamp == "" {
|
|
logf("cache: no expiry stored, refreshing")
|
|
return nil, false
|
|
}
|
|
expiry, err := time.Parse(time.RFC3339, ec.Status.ExpirationTimestamp)
|
|
if err != nil {
|
|
logf("cache: unparseable expiry %q, refreshing", ec.Status.ExpirationTimestamp)
|
|
return nil, false
|
|
}
|
|
remaining := time.Until(expiry)
|
|
if remaining < 5*time.Minute {
|
|
logf("cache: expiring soon (%v remaining), refreshing", remaining.Truncate(time.Second))
|
|
return nil, false
|
|
}
|
|
logf("cache: hit — cert for expires %s (%v remaining)",
|
|
expiry.Format(time.RFC3339), remaining.Truncate(time.Second))
|
|
return &ec, true
|
|
}
|
|
|
|
func credWriteCache(serverURL string, ec *ExecCredential, logf func(string, ...any)) {
|
|
dir := credCacheDir()
|
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
|
logf("cache: failed to create dir: %v", err)
|
|
return
|
|
}
|
|
data, err := json.Marshal(ec)
|
|
if err != nil {
|
|
logf("cache: marshal failed: %v", err)
|
|
return
|
|
}
|
|
path := credCacheFile(serverURL)
|
|
if err := os.WriteFile(path, data, 0600); err != nil {
|
|
logf("cache: write failed: %v", err)
|
|
return
|
|
}
|
|
logf("cache: saved to %s", path)
|
|
}
|
|
|
|
// ── Kerberos helpers ───────────────────────────────────────────────────────────
|
|
|
|
func krb5ConfigPath() string {
|
|
return cmp.Or(os.Getenv("KRB5_CONFIG"), "/etc/krb5.conf")
|
|
}
|
|
|
|
// ccachePath returns the path to the active Kerberos credential cache.
|
|
// Respects $KRB5CCNAME; strips the "FILE:" prefix if present.
|
|
// Non-file ccache types (API:, KEYRING:, DIR:) are not supported by gokrb5
|
|
// and will produce an error when LoadCCache is called.
|
|
func ccachePath() string {
|
|
if v := os.Getenv("KRB5CCNAME"); v != "" {
|
|
return strings.TrimPrefix(v, "FILE:")
|
|
}
|
|
return fmt.Sprintf("/tmp/krb5cc_%d", os.Getuid())
|
|
}
|
|
|
|
// ── Password prompt ────────────────────────────────────────────────────────────
|
|
|
|
func credPromptPassword(username string) (string, error) {
|
|
terminal, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
|
|
if err != nil {
|
|
return "", fmt.Errorf(
|
|
"cannot open terminal and $WARD_PASSWORD is not set\n" +
|
|
"hint: run 'kinit' for Kerberos auth, or set $WARD_PASSWORD for non-interactive use")
|
|
}
|
|
|
|
oldState, err := term.MakeRaw(int(terminal.Fd()))
|
|
if err != nil {
|
|
return "", fmt.Errorf("setting terminal raw mode: %w", err)
|
|
}
|
|
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
|
go func() {
|
|
<-sigCh
|
|
term.Restore(int(terminal.Fd()), oldState)
|
|
os.Exit(1)
|
|
}()
|
|
|
|
fmt.Fprintf(terminal, "Password for %s: ", username)
|
|
pw, err := term.ReadPassword(int(terminal.Fd()))
|
|
fmt.Fprintf(terminal, "\r\n") // newline after the hidden input
|
|
signal.Stop(sigCh)
|
|
term.Restore(int(terminal.Fd()), oldState)
|
|
if err != nil {
|
|
return "", fmt.Errorf("reading password: %w", err)
|
|
}
|
|
return string(pw), nil
|
|
}
|