This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
@@ -8,55 +8,33 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/term"
|
||||
|
||||
krb5client "github.com/jcmturner/gokrb5/v8/client"
|
||||
"github.com/jcmturner/gokrb5/v8/config"
|
||||
"github.com/jcmturner/gokrb5/v8/credentials"
|
||||
"github.com/jcmturner/gokrb5/v8/spnego"
|
||||
|
||||
"git.shee.sh/james/ward/pkg/handler"
|
||||
)
|
||||
|
||||
func credFetch(server, username string, noKerberos bool, logf func(string, ...any)) (*ExecCredential, error) {
|
||||
func credFetch(server string, noKerberos bool, logf func(string, ...any)) (*handler.ExecCredential, error) {
|
||||
url := strings.TrimRight(server, "/") + "/credential"
|
||||
|
||||
// ── Kerberos SPNEGO ───────────────────────────────────────────────────────
|
||||
if !noKerberos {
|
||||
body, err := credFetchKerberos(url, logf)
|
||||
if err != nil {
|
||||
logf("Kerberos: %v — falling back to Basic auth", err)
|
||||
fmt.Fprintf(os.Stderr, "ward: Kerberos failed (%v); falling back to Basic auth\nhint: run 'kinit' to avoid the password prompt\n", err)
|
||||
logf("Kerberos: %v", err)
|
||||
} else {
|
||||
return credParse(body)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Basic auth ────────────────────────────────────────────────────────────
|
||||
if username == "" {
|
||||
username = os.Getenv("USER")
|
||||
}
|
||||
password := os.Getenv("WARD_PASSWORD")
|
||||
if password == "" {
|
||||
var err error
|
||||
password, err = credPromptPassword(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
logf("Basic: using password from $WARD_PASSWORD")
|
||||
}
|
||||
|
||||
body, err := credFetchBasic(url, username, password, logf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return credParse(body)
|
||||
return nil, fmt.Errorf("no valid credential — run 'ward login --server %s'", server)
|
||||
}
|
||||
|
||||
func credFetchKerberos(url string, logf func(string, ...any)) ([]byte, error) {
|
||||
@@ -149,8 +127,8 @@ func credFetchBasic(url, username, password string, logf func(string, ...any)) (
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func credParse(body []byte) (*ExecCredential, error) {
|
||||
var ec ExecCredential
|
||||
func credParse(body []byte) (*handler.ExecCredential, error) {
|
||||
var ec handler.ExecCredential
|
||||
if err := json.Unmarshal(body, &ec); err != nil {
|
||||
return nil, fmt.Errorf("parsing server response: %w", err)
|
||||
}
|
||||
@@ -160,7 +138,7 @@ func credParse(body []byte) (*ExecCredential, error) {
|
||||
return &ec, nil
|
||||
}
|
||||
|
||||
func credPrint(ec *ExecCredential) error {
|
||||
func credPrint(ec *handler.ExecCredential) error {
|
||||
return json.NewEncoder(os.Stdout).Encode(ec)
|
||||
}
|
||||
|
||||
@@ -182,7 +160,7 @@ func credCacheFile(serverURL string) string {
|
||||
return filepath.Join(credCacheDir(), fmt.Sprintf("%x.json", h))
|
||||
}
|
||||
|
||||
func credReadCache(serverURL string, logf func(string, ...any)) (*ExecCredential, bool) {
|
||||
func credReadCache(serverURL string, logf func(string, ...any)) (*handler.ExecCredential, bool) {
|
||||
path := credCacheFile(serverURL)
|
||||
logf("cache: checking %s", path)
|
||||
|
||||
@@ -191,7 +169,7 @@ func credReadCache(serverURL string, logf func(string, ...any)) (*ExecCredential
|
||||
logf("cache: miss (%v)", err)
|
||||
return nil, false
|
||||
}
|
||||
var ec ExecCredential
|
||||
var ec handler.ExecCredential
|
||||
if err := json.Unmarshal(data, &ec); err != nil {
|
||||
logf("cache: corrupt, ignoring (%v)", err)
|
||||
return nil, false
|
||||
@@ -215,7 +193,7 @@ func credReadCache(serverURL string, logf func(string, ...any)) (*ExecCredential
|
||||
return &ec, true
|
||||
}
|
||||
|
||||
func credWriteCache(serverURL string, ec *ExecCredential, logf func(string, ...any)) {
|
||||
func credWriteCache(serverURL string, ec *handler.ExecCredential, logf func(string, ...any)) {
|
||||
dir := credCacheDir()
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
logf("cache: failed to create dir: %v", err)
|
||||
@@ -260,37 +238,3 @@ func ccachePath() (string, error) {
|
||||
}
|
||||
return fmt.Sprintf("/tmp/krb5cc_%d", os.Getuid()), nil
|
||||
}
|
||||
|
||||
// ── Password prompt ────────────────────────────────────────────────────────────
|
||||
|
||||
func credPromptPassword(username string) (string, error) {
|
||||
terminal, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf(
|
||||
"cannot open terminal and $WARD_PASSWORD is not set\n" +
|
||||
"hint: run 'kinit' for Kerberos auth, or set $WARD_PASSWORD for non-interactive use")
|
||||
}
|
||||
|
||||
oldState, err := term.MakeRaw(int(terminal.Fd()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("setting terminal raw mode: %w", err)
|
||||
}
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
term.Restore(int(terminal.Fd()), oldState)
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
fmt.Fprintf(terminal, "Password for %s: ", username)
|
||||
pw, err := term.ReadPassword(int(terminal.Fd()))
|
||||
fmt.Fprintf(terminal, "\r\n") // newline after the hidden input
|
||||
signal.Stop(sigCh)
|
||||
term.Restore(int(terminal.Fd()), oldState)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
return string(pw), nil
|
||||
}
|
||||
68
cmd/credential.go
Normal file
68
cmd/credential.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newCredentialCmd() *cobra.Command {
|
||||
var (
|
||||
server string
|
||||
noKerberos bool
|
||||
noCache bool
|
||||
debugFlag bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "credential",
|
||||
Short: "kubectl exec credential plugin — serve a cached ExecCredential to kubectl",
|
||||
Long: `Acts as a kubectl exec credential plugin. Returns a cached ExecCredential
|
||||
JSON to kubectl. On a cache miss, silently attempts Kerberos SPNEGO; if that
|
||||
also fails, exits with an error directing the user to run 'ward login'.
|
||||
|
||||
Run 'ward login' once to authenticate and populate the cache. After that,
|
||||
kubectl works silently until the credential expires.
|
||||
|
||||
Debug output goes to stderr (kubectl surfaces this to the terminal):
|
||||
|
||||
WARD_DEBUG=1 kubectl get nodes`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if server == "" {
|
||||
return fmt.Errorf("--server is required")
|
||||
}
|
||||
|
||||
server = normalizeServer(server)
|
||||
|
||||
logf := func(format string, a ...any) {
|
||||
if debugFlag {
|
||||
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCache {
|
||||
if ec, ok := credReadCache(server, logf); ok {
|
||||
return credPrint(ec)
|
||||
}
|
||||
}
|
||||
|
||||
ec, err := credFetch(server, noKerberos, logf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !noCache {
|
||||
credWriteCache(server, ec, logf)
|
||||
}
|
||||
return credPrint(ec)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
|
||||
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos SPNEGO")
|
||||
cmd.Flags().BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential")
|
||||
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
139
cmd/login.go
Normal file
139
cmd/login.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"git.shee.sh/james/ward/pkg/handler"
|
||||
"git.shee.sh/james/ward/pkg/kubeconfig"
|
||||
)
|
||||
|
||||
func newLoginCmd() *cobra.Command {
|
||||
var (
|
||||
server string
|
||||
username string
|
||||
contextName string
|
||||
noKerberos bool
|
||||
noSetContext bool
|
||||
debugFlag bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "Authenticate to ward and configure kubectl",
|
||||
Long: `Authenticates to the ward server, caches a long-lived credential, and
|
||||
updates ~/.kube/config with the cluster, user (exec plugin), and context.
|
||||
After login, kubectl works silently until the credential expires, at
|
||||
which point you run 'ward login' again.
|
||||
|
||||
Authentication priority:
|
||||
1. Kerberos SPNEGO using the active credential cache (from kinit)
|
||||
2. Password prompt`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if server == "" {
|
||||
return fmt.Errorf("--server is required")
|
||||
}
|
||||
|
||||
server = normalizeServer(server)
|
||||
|
||||
logf := func(format string, a ...any) {
|
||||
if debugFlag {
|
||||
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
|
||||
}
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
username = os.Getenv("USER")
|
||||
}
|
||||
|
||||
ec, err := loginFetch(server, username, noKerberos, logf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
credWriteCache(server, ec, logf)
|
||||
|
||||
bootstrap, err := fetchBootstrap(server, username, logf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching bootstrap kubeconfig: %w", err)
|
||||
}
|
||||
|
||||
if contextName != "" {
|
||||
kubeconfig.RenameContext(bootstrap, contextName)
|
||||
}
|
||||
|
||||
contextApplied := bootstrap.CurrentContext
|
||||
if err := kubeconfig.Merge(bootstrap, !noSetContext); err != nil {
|
||||
return fmt.Errorf("updating kubeconfig: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "ward: logged in — context %q configured in %s\n",
|
||||
contextApplied, kubeconfig.FilePath())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
|
||||
cmd.Flags().StringVar(&username, "username", "", "username (default: $USER)")
|
||||
cmd.Flags().StringVar(&contextName, "context", "", "kubectl context/cluster name (overrides server default)")
|
||||
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; use password auth")
|
||||
cmd.Flags().BoolVar(&noSetContext, "no-set-context", false, "do not set as current-context")
|
||||
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func loginFetch(server, username string, noKerberos bool, logf func(string, ...any)) (*handler.ExecCredential, error) {
|
||||
loginURL := strings.TrimRight(server, "/") + "/credential?login=true"
|
||||
|
||||
if !noKerberos {
|
||||
body, err := credFetchKerberos(loginURL, logf)
|
||||
if err != nil {
|
||||
logf("Kerberos: %v — falling back to password auth", err)
|
||||
fmt.Fprintf(os.Stderr, "ward: Kerberos failed (%v); using password auth\nhint: run 'kinit' to avoid the password prompt next time\n", err)
|
||||
} else {
|
||||
return credParse(body)
|
||||
}
|
||||
}
|
||||
|
||||
password, err := promptPassword(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := credFetchBasic(loginURL, username, password, logf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return credParse(body)
|
||||
}
|
||||
|
||||
func fetchBootstrap(server, username string, logf func(string, ...any)) (*kubeconfig.KubeConfig, error) {
|
||||
bootstrapURL := strings.TrimRight(server, "/") + "/bootstrap?user=" + url.QueryEscape(username)
|
||||
logf("HTTP: GET %s", bootstrapURL)
|
||||
|
||||
resp, err := http.Get(bootstrapURL) //nolint:gosec // URL derived from user-supplied server flag
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("HTTP request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("server returned %s", resp.Status)
|
||||
}
|
||||
|
||||
var cfg kubeconfig.KubeConfig
|
||||
if err := yaml.Unmarshal(body, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parsing bootstrap kubeconfig: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
46
cmd/root.go
Normal file
46
cmd/root.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Execute builds the root command and runs it.
|
||||
func Execute() {
|
||||
root := newRootCmd()
|
||||
if err := root.Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func newRootCmd() *cobra.Command {
|
||||
root := &cobra.Command{
|
||||
Use: "ward",
|
||||
Short: "Kubernetes credential gateway",
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
root.AddCommand(newServeCmd())
|
||||
root.AddCommand(newCredentialCmd())
|
||||
root.AddCommand(newLoginCmd())
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
// normalizeServer ensures server has an https:// scheme and a port.
|
||||
// Shared by the credential and login commands.
|
||||
func normalizeServer(server string) string {
|
||||
if !strings.Contains(server, "://") {
|
||||
server = "https://" + server
|
||||
}
|
||||
parts := strings.Split(server, ":")
|
||||
if len(parts) == 2 {
|
||||
server += ":8443"
|
||||
}
|
||||
return server
|
||||
}
|
||||
252
cmd/serve.go
Normal file
252
cmd/serve.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"git.shee.sh/james/ward/pkg/auth"
|
||||
"git.shee.sh/james/ward/pkg/cert"
|
||||
"git.shee.sh/james/ward/pkg/handler"
|
||||
)
|
||||
|
||||
func newServeCmd() *cobra.Command {
|
||||
fqdn := detectFQDN()
|
||||
domain := domainPart(fqdn)
|
||||
realm := strings.ToUpper(domain)
|
||||
|
||||
var (
|
||||
addr string
|
||||
tlsCert string
|
||||
tlsKey string
|
||||
k3sServer string
|
||||
serverCACert string
|
||||
clientCACert string
|
||||
clientCAKey string
|
||||
certDuration time.Duration
|
||||
clusterName string
|
||||
|
||||
ldapDomain string
|
||||
ldapOn bool
|
||||
ldapURI string
|
||||
ldapBindDN string
|
||||
|
||||
loginCertDuration time.Duration
|
||||
|
||||
kerberosOn bool
|
||||
keytabPath string
|
||||
spn string
|
||||
|
||||
htpasswdPath string
|
||||
debugFlag bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Short: "Run the ward authentication server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
dbg := log.New(io.Discard, "[ward] ", 0)
|
||||
if debugFlag {
|
||||
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
|
||||
dbg.Print("debug logging enabled")
|
||||
}
|
||||
|
||||
if k3sServer == "" {
|
||||
return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
|
||||
}
|
||||
|
||||
// ── LDAP ──────────────────────────────────────────────────────────────────
|
||||
ldapBindPassword := os.Getenv("WARD_LDAP_BIND_PASSWORD")
|
||||
var ldapAuth *auth.LDAPAuth
|
||||
if ldapURI != "" || ldapOn {
|
||||
var la *auth.LDAPAuth
|
||||
var err error
|
||||
if ldapURI != "" {
|
||||
la, err = auth.NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword, dbg)
|
||||
} else {
|
||||
la, err = auth.NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword, dbg)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("LDAP: %w", err)
|
||||
}
|
||||
ldapAuth = la
|
||||
anon := ldapBindDN == ""
|
||||
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.Host(), la.Port(), la.UseTLS(), ldapDomain, anon)
|
||||
}
|
||||
|
||||
if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
|
||||
return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
|
||||
}
|
||||
|
||||
// ── Kerberos ──────────────────────────────────────────────────────────────
|
||||
var krbAuth *auth.KerberosAuth
|
||||
if kerberosOn {
|
||||
ka, err := auth.NewKerberosAuth(keytabPath, spn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Kerberos: %w", err)
|
||||
}
|
||||
krbAuth = ka
|
||||
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm)
|
||||
}
|
||||
|
||||
// ── htpasswd ──────────────────────────────────────────────────────────────
|
||||
var htpasswdAuth *auth.HtpasswdAuth
|
||||
if htpasswdPath != "" {
|
||||
ha, err := auth.NewHtpasswdAuth(htpasswdPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("htpasswd: %w", err)
|
||||
}
|
||||
htpasswdAuth = ha
|
||||
log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
|
||||
|
||||
sighup := make(chan os.Signal, 1)
|
||||
signal.Notify(sighup, syscall.SIGHUP)
|
||||
go func() {
|
||||
for range sighup {
|
||||
if err := htpasswdAuth.Reload(); err != nil {
|
||||
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
|
||||
} else {
|
||||
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ── Handler ───────────────────────────────────────────────────────────────
|
||||
h, err := handler.NewHandler(ldapAuth, krbAuth, htpasswdAuth, &cert.CertConfig{
|
||||
ServerURL: k3sServer,
|
||||
ServerCACert: serverCACert,
|
||||
ClientCACert: clientCACert,
|
||||
ClientCAKey: clientCAKey,
|
||||
Duration: certDuration,
|
||||
LoginDuration: loginCertDuration,
|
||||
ClusterName: clusterName,
|
||||
}, dbg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handler: %w", err)
|
||||
}
|
||||
|
||||
// ── TLS ───────────────────────────────────────────────────────────────────
|
||||
if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil {
|
||||
return fmt.Errorf("TLS: loading certificate: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
c, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reloading TLS cert: %w", err)
|
||||
}
|
||||
return &c, nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := tls.Listen("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s: %w", addr, err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
|
||||
mux.HandleFunc("/credential", h.ServeCredential)
|
||||
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
|
||||
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
|
||||
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
|
||||
" GET /kubeconfig kubeconfig with embedded client certificate\n")
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: mux,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
|
||||
return srv.Serve(ln)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
|
||||
cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
|
||||
cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
|
||||
cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
|
||||
cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
|
||||
cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
|
||||
cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
|
||||
cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
|
||||
cmd.Flags().DurationVar(&loginCertDuration, "login-cert-duration", 168*time.Hour, "Validity period of certificates issued by 'ward login' (default 7 days)")
|
||||
cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
|
||||
|
||||
cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
|
||||
cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
|
||||
cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
|
||||
cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
|
||||
// LDAP bind password is read exclusively from $WARD_LDAP_BIND_PASSWORD to avoid
|
||||
// exposure in process listings.
|
||||
|
||||
cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
|
||||
cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
|
||||
cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
|
||||
|
||||
cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
|
||||
|
||||
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// detectFQDN returns the fully-qualified domain name of the local host,
|
||||
// falling back to the short hostname if DNS resolution fails.
|
||||
func detectFQDN() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "localhost"
|
||||
}
|
||||
if strings.Contains(hostname, ".") {
|
||||
return hostname
|
||||
}
|
||||
addrs, err := net.LookupHost(hostname)
|
||||
if err != nil || len(addrs) == 0 {
|
||||
return hostname
|
||||
}
|
||||
names, err := net.LookupAddr(addrs[0])
|
||||
if err != nil || len(names) == 0 {
|
||||
return hostname
|
||||
}
|
||||
return strings.TrimSuffix(names[0], ".")
|
||||
}
|
||||
|
||||
// domainPart strips the first label from a FQDN.
|
||||
// "host.example.com" → "example.com"
|
||||
func domainPart(fqdn string) string {
|
||||
if idx := strings.IndexByte(fqdn, '.'); idx >= 0 {
|
||||
return fqdn[idx+1:]
|
||||
}
|
||||
return fqdn
|
||||
}
|
||||
|
||||
// firstLabel returns the first dot-separated label of a domain name.
|
||||
// "example.com" → "example"
|
||||
func firstLabel(domain string) string {
|
||||
if idx := strings.IndexByte(domain, '.'); idx >= 0 {
|
||||
return domain[:idx]
|
||||
}
|
||||
return domain
|
||||
}
|
||||
40
cmd/tty.go
Normal file
40
cmd/tty.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func promptPassword(username string) (string, error) {
|
||||
terminal, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot open terminal for password prompt\nhint: run 'kinit' for Kerberos auth")
|
||||
}
|
||||
|
||||
oldState, err := term.MakeRaw(int(terminal.Fd()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("setting terminal raw mode: %w", err)
|
||||
}
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
term.Restore(int(terminal.Fd()), oldState)
|
||||
os.Exit(1)
|
||||
}()
|
||||
|
||||
fmt.Fprintf(terminal, "Password for %s: ", username)
|
||||
pw, err := term.ReadPassword(int(terminal.Fd()))
|
||||
fmt.Fprintf(terminal, "\r\n")
|
||||
signal.Stop(sigCh)
|
||||
term.Restore(int(terminal.Fd()), oldState)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading password: %w", err)
|
||||
}
|
||||
return string(pw), nil
|
||||
}
|
||||
1
go.mod
1
go.mod
@@ -24,4 +24,5 @@ require (
|
||||
github.com/spf13/pflag v1.0.9 // indirect
|
||||
golang.org/x/net v0.22.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
334
main.go
334
main.go
@@ -1,337 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// dbg is a package-level debug logger, active only when --debug is passed.
|
||||
// Defaults to discarding output; serve command points it at stderr when --debug is set.
|
||||
var dbg = log.New(io.Discard, "[ward] ", 0)
|
||||
import "git.shee.sh/james/ward/cmd"
|
||||
|
||||
func main() {
|
||||
root := &cobra.Command{
|
||||
Use: "ward",
|
||||
Short: "Kubernetes credential gateway",
|
||||
SilenceUsage: true,
|
||||
SilenceErrors: true,
|
||||
}
|
||||
|
||||
root.AddCommand(newServeCmd())
|
||||
root.AddCommand(newCredentialCmd())
|
||||
|
||||
if err := root.Execute(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func newServeCmd() *cobra.Command {
|
||||
fqdn := detectFQDN()
|
||||
domain := domainPart(fqdn)
|
||||
realm := strings.ToUpper(domain)
|
||||
|
||||
var (
|
||||
addr string
|
||||
tlsCert string
|
||||
tlsKey string
|
||||
k3sServer string
|
||||
serverCACert string
|
||||
clientCACert string
|
||||
clientCAKey string
|
||||
certDuration time.Duration
|
||||
clusterName string
|
||||
|
||||
ldapDomain string
|
||||
ldapOn bool
|
||||
ldapURI string
|
||||
ldapBindDN string
|
||||
ldapBindPassword string
|
||||
|
||||
kerberosOn bool
|
||||
keytabPath string
|
||||
spn string
|
||||
|
||||
htpasswdPath string
|
||||
debugFlag bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Short: "Run the ward authentication server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if debugFlag {
|
||||
dbg = log.New(os.Stderr, "[ward] ", log.Ltime)
|
||||
dbg.Print("debug logging enabled")
|
||||
}
|
||||
|
||||
if k3sServer == "" {
|
||||
return fmt.Errorf("--k3s-server is required (e.g. https://k3s.example.com:6443)")
|
||||
}
|
||||
|
||||
// ── LDAP ──────────────────────────────────────────────────────────────────
|
||||
var ldapAuth *LDAPAuth
|
||||
if ldapURI != "" || ldapOn {
|
||||
var la *LDAPAuth
|
||||
var err error
|
||||
if ldapURI != "" {
|
||||
la, err = NewLDAPAuthFromURI(ldapURI, ldapDomain, ldapBindDN, ldapBindPassword)
|
||||
} else {
|
||||
la, err = NewLDAPAuth(ldapDomain, ldapBindDN, ldapBindPassword)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("LDAP: %w", err)
|
||||
}
|
||||
ldapAuth = la
|
||||
anon := ldapBindDN == ""
|
||||
log.Printf("LDAP: %s:%d (TLS=%v) domain=%s anon=%v", la.host, la.port, la.useTLS, ldapDomain, anon)
|
||||
}
|
||||
|
||||
if ldapAuth == nil && !kerberosOn && htpasswdPath == "" {
|
||||
return fmt.Errorf("no authentication providers configured: use at least one of --ldap, --ldap-uri, --kerberos, or --htpasswd")
|
||||
}
|
||||
|
||||
// ── Kerberos ──────────────────────────────────────────────────────────────
|
||||
var krbAuth *KerberosAuth
|
||||
if kerberosOn {
|
||||
ka, err := NewKerberosAuth(keytabPath, spn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Kerberos: %w", err)
|
||||
}
|
||||
krbAuth = ka
|
||||
log.Printf("Kerberos: keytab=%s SPN=%s realm=%s", keytabPath, spn, realm)
|
||||
}
|
||||
|
||||
// ── htpasswd ──────────────────────────────────────────────────────────────
|
||||
var htpasswdAuth *HtpasswdAuth
|
||||
if htpasswdPath != "" {
|
||||
ha, err := NewHtpasswdAuth(htpasswdPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("htpasswd: %w", err)
|
||||
}
|
||||
htpasswdAuth = ha
|
||||
log.Printf("htpasswd: %s (%d entries)", htpasswdPath, ha.Len())
|
||||
|
||||
sighup := make(chan os.Signal, 1)
|
||||
signal.Notify(sighup, syscall.SIGHUP)
|
||||
go func() {
|
||||
for range sighup {
|
||||
if err := htpasswdAuth.Reload(); err != nil {
|
||||
log.Printf("SIGHUP: htpasswd reload failed: %v", err)
|
||||
} else {
|
||||
log.Printf("SIGHUP: htpasswd reloaded (%d entries)", htpasswdAuth.Len())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ── Handler ───────────────────────────────────────────────────────────────
|
||||
h, err := NewHandler(ldapAuth, krbAuth, htpasswdAuth, &CertConfig{
|
||||
ServerURL: k3sServer,
|
||||
ServerCACert: serverCACert,
|
||||
ClientCACert: clientCACert,
|
||||
ClientCAKey: clientCAKey,
|
||||
Duration: certDuration,
|
||||
ClusterName: clusterName,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("handler: %w", err)
|
||||
}
|
||||
|
||||
// ── TLS ───────────────────────────────────────────────────────────────────
|
||||
if _, err := tls.LoadX509KeyPair(tlsCert, tlsKey); err != nil {
|
||||
return fmt.Errorf("TLS: loading certificate: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reloading TLS cert: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := tls.Listen("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s: %w", addr, err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/kubeconfig", h.ServeHTTP)
|
||||
mux.HandleFunc("/credential", h.ServeCredential)
|
||||
mux.HandleFunc("/bootstrap", h.ServeBootstrap)
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprintf(w, "ward — Kubernetes credential gateway\n\n"+
|
||||
" GET /bootstrap kubeconfig with exec plugin pre-wired (no auth required)\n"+
|
||||
" GET /credential ExecCredential JSON for kubectl exec plugin\n"+
|
||||
" GET /kubeconfig kubeconfig with embedded client certificate\n")
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: mux,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
log.Printf("ward listening on %s (cert-duration=%s)", addr, certDuration)
|
||||
return srv.Serve(ln)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&addr, "addr", ":8443", "Listen address")
|
||||
cmd.Flags().StringVar(&tlsCert, "tls-cert", fmt.Sprintf("/etc/letsencrypt/live/%s/fullchain.pem", fqdn), "TLS certificate file (Let's Encrypt fullchain)")
|
||||
cmd.Flags().StringVar(&tlsKey, "tls-key", fmt.Sprintf("/etc/letsencrypt/live/%s/privkey.pem", fqdn), "TLS private key file")
|
||||
cmd.Flags().StringVar(&k3sServer, "k3s-server", "", "k3s API server URL written into returned kubeconfigs (required, e.g. https://k3s.example.com:6443)")
|
||||
cmd.Flags().StringVar(&serverCACert, "server-ca-cert", "/var/lib/rancher/k3s/server/tls/server-ca.crt", "k3s server CA certificate (embedded in returned kubeconfig)")
|
||||
cmd.Flags().StringVar(&clientCACert, "client-ca-cert", "/var/lib/rancher/k3s/server/tls/client-ca.crt", "k3s client CA certificate (signs user certs)")
|
||||
cmd.Flags().StringVar(&clientCAKey, "client-ca-key", "/var/lib/rancher/k3s/server/tls/client-ca.key", "k3s client CA key")
|
||||
cmd.Flags().DurationVar(&certDuration, "cert-duration", 24*time.Hour, "Validity period of generated client certificates")
|
||||
cmd.Flags().StringVar(&clusterName, "cluster-name", firstLabel(domain), "Cluster/context name written into generated kubeconfigs")
|
||||
|
||||
cmd.Flags().StringVar(&ldapDomain, "domain", domain, "Domain for LDAP SRV discovery and Kerberos realm derivation")
|
||||
cmd.Flags().BoolVar(&ldapOn, "ldap", false, "Enable LDAP authentication (auto-discovered via DNS SRV)")
|
||||
cmd.Flags().StringVar(&ldapURI, "ldap-uri", "", "LDAP server URI, e.g. ldaps://ldap.example.com (implies --ldap; overrides DNS SRV)")
|
||||
cmd.Flags().StringVar(&ldapBindDN, "ldap-bind-dn", os.Getenv("WARD_LDAP_BIND_DN"), "LDAP bind DN for search (default: anonymous; env: WARD_LDAP_BIND_DN)")
|
||||
cmd.Flags().StringVar(&ldapBindPassword, "ldap-bind-password", os.Getenv("WARD_LDAP_BIND_PASSWORD"), "LDAP bind password (env: WARD_LDAP_BIND_PASSWORD; caution: visible in ps output)")
|
||||
|
||||
cmd.Flags().BoolVar(&kerberosOn, "kerberos", false, "Enable Kerberos SPNEGO authentication (Authorization: Negotiate)")
|
||||
cmd.Flags().StringVar(&keytabPath, "keytab", "/etc/krb5.keytab", "Kerberos service keytab path")
|
||||
cmd.Flags().StringVar(&spn, "spn", "HTTP/"+fqdn, fmt.Sprintf("Kerberos service principal name (SPN) (default realm %s — create with: kadmin: addprinc -randkey HTTP/%s@%s)", realm, fqdn, realm))
|
||||
|
||||
cmd.Flags().StringVar(&htpasswdPath, "htpasswd", "", "Path to an Apache-compatible htpasswd file (bcrypt recommended: htpasswd -B -c file user)")
|
||||
|
||||
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "Enable verbose debug logging (also: $WARD_DEBUG=1)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newCredentialCmd() *cobra.Command {
|
||||
var (
|
||||
server string
|
||||
username string
|
||||
noKerberos bool
|
||||
noCache bool
|
||||
debugFlag bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "credential",
|
||||
Short: "kubectl exec credential plugin — fetch an ExecCredential from ward",
|
||||
Long: `Acts as a kubectl exec credential plugin. Fetches an ExecCredential JSON
|
||||
from the ward server and prints it to stdout for kubectl to consume.
|
||||
|
||||
Authentication priority:
|
||||
1. Kerberos SPNEGO using the active credential cache (from kinit)
|
||||
2. Basic auth — prompts for password, or reads $WARD_PASSWORD
|
||||
|
||||
Credentials are cached in ~/.cache/ward/ and reused until 5 minutes
|
||||
before expiry, so kubectl invocations are fast after the first call.
|
||||
|
||||
Debug output goes to stderr (kubectl surfaces this to the terminal):
|
||||
|
||||
WARD_DEBUG=1 kubectl get nodes`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if server == "" {
|
||||
return fmt.Errorf("--server is required")
|
||||
}
|
||||
|
||||
// prepend https if no scheme is given, for user convenience
|
||||
if !strings.Contains(server, "://") {
|
||||
server = "https://" + server
|
||||
}
|
||||
|
||||
// append port 8443 if no port is given
|
||||
parts := strings.Split(server, ":")
|
||||
if len(parts) == 2 {
|
||||
server += ":8443"
|
||||
}
|
||||
|
||||
logf := func(format string, a ...any) {
|
||||
if debugFlag {
|
||||
fmt.Fprintf(os.Stderr, "[ward] "+format+"\n", a...)
|
||||
}
|
||||
}
|
||||
|
||||
if !noCache {
|
||||
if ec, ok := credReadCache(server, logf); ok {
|
||||
return credPrint(ec)
|
||||
}
|
||||
}
|
||||
|
||||
ec, err := credFetch(server, username, noKerberos, logf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !noCache {
|
||||
credWriteCache(server, ec, logf)
|
||||
}
|
||||
return credPrint(ec)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&server, "server", "", "ward server URL (required)")
|
||||
cmd.Flags().StringVar(&username, "username", "", "username for Basic auth fallback (default: $USER)")
|
||||
cmd.Flags().BoolVar(&noKerberos, "no-kerberos", false, "skip Kerberos; always use Basic auth")
|
||||
cmd.Flags().BoolVar(&noCache, "no-cache", false, "bypass local cache; always fetch a fresh credential")
|
||||
cmd.Flags().BoolVar(&debugFlag, "debug", os.Getenv("WARD_DEBUG") != "", "verbose debug output to stderr (also: $WARD_DEBUG=1)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// detectFQDN returns the fully-qualified domain name of the local host,
|
||||
// falling back to the short hostname if DNS resolution fails.
|
||||
func detectFQDN() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "localhost"
|
||||
}
|
||||
if strings.Contains(hostname, ".") {
|
||||
return hostname
|
||||
}
|
||||
addrs, err := net.LookupHost(hostname)
|
||||
if err != nil || len(addrs) == 0 {
|
||||
return hostname
|
||||
}
|
||||
names, err := net.LookupAddr(addrs[0])
|
||||
if err != nil || len(names) == 0 {
|
||||
return hostname
|
||||
}
|
||||
return strings.TrimSuffix(names[0], ".")
|
||||
}
|
||||
|
||||
// domainPart strips the first label from a FQDN.
|
||||
// "host.example.com" → "example.com"
|
||||
func domainPart(fqdn string) string {
|
||||
if idx := strings.IndexByte(fqdn, '.'); idx >= 0 {
|
||||
return fqdn[idx+1:]
|
||||
}
|
||||
return fqdn
|
||||
}
|
||||
|
||||
// firstLabel returns the first dot-separated label of a domain name.
|
||||
// "example.com" → "example"
|
||||
func firstLabel(domain string) string {
|
||||
if idx := strings.IndexByte(domain, '.'); idx >= 0 {
|
||||
return domain[:idx]
|
||||
}
|
||||
return domain
|
||||
cmd.Execute()
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,8 +1,9 @@
|
||||
package main
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -19,6 +20,7 @@ type LDAPAuth struct {
|
||||
useTLS bool // true = LDAPS (TLS from the start); false = STARTTLS on port 389
|
||||
bindDN string // empty = anonymous bind
|
||||
bindPassword string
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
// NewLDAPAuth discovers the LDAP server for domain via the standard _ldap._tcp
|
||||
@@ -26,7 +28,7 @@ type LDAPAuth struct {
|
||||
// port 389 uses STARTTLS; anything else (typically 636) uses LDAPS.
|
||||
// _ldaps._tcp is not an IANA-registered SRV type and is not consulted.
|
||||
// bindDN and bindPassword are used for the search bind; both empty = anonymous.
|
||||
func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) {
|
||||
func NewLDAPAuth(domain, bindDN, bindPassword string, dbg *log.Logger) (*LDAPAuth, error) {
|
||||
host, port, err := discoverLDAP(domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -38,6 +40,7 @@ func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) {
|
||||
useTLS: port != 389,
|
||||
bindDN: bindDN,
|
||||
bindPassword: bindPassword,
|
||||
log: dbg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -46,7 +49,7 @@ func NewLDAPAuth(domain, bindDN, bindPassword string) (*LDAPAuth, error) {
|
||||
// Port defaults to 389 for ldap:// and 636 for ldaps:// if not specified.
|
||||
// domain is still required for base-DN derivation and UPN construction.
|
||||
// bindDN and bindPassword are used for the search bind; both empty = anonymous.
|
||||
func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string) (*LDAPAuth, error) {
|
||||
func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string, dbg *log.Logger) (*LDAPAuth, error) {
|
||||
u, err := url.Parse(rawURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid LDAP URI %q: %w", rawURI, err)
|
||||
@@ -78,9 +81,19 @@ func NewLDAPAuthFromURI(rawURI, domain, bindDN, bindPassword string) (*LDAPAuth,
|
||||
useTLS: useTLS,
|
||||
bindDN: bindDN,
|
||||
bindPassword: bindPassword,
|
||||
log: dbg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Host returns the LDAP server hostname (used for logging).
|
||||
func (a *LDAPAuth) Host() string { return a.host }
|
||||
|
||||
// Port returns the LDAP server port (used for logging).
|
||||
func (a *LDAPAuth) Port() int { return a.port }
|
||||
|
||||
// UseTLS reports whether TLS is in use (used for logging).
|
||||
func (a *LDAPAuth) UseTLS() bool { return a.useTLS }
|
||||
|
||||
func discoverLDAP(domain string) (host string, port int, err error) {
|
||||
_, addrs, err := net.LookupSRV("ldap", "tcp", domain)
|
||||
if err != nil || len(addrs) == 0 {
|
||||
@@ -93,10 +106,10 @@ func (a *LDAPAuth) connect() (*ldap.Conn, error) {
|
||||
addr := fmt.Sprintf("%s:%d", a.host, a.port)
|
||||
tlsConfig := &tls.Config{ServerName: a.host}
|
||||
if a.useTLS {
|
||||
dbg.Printf("LDAP: dialing TLS %s", addr)
|
||||
a.log.Printf("LDAP: dialing TLS %s", addr)
|
||||
return ldap.DialTLS("tcp", addr, tlsConfig)
|
||||
}
|
||||
dbg.Printf("LDAP: dialing %s + STARTTLS", addr)
|
||||
a.log.Printf("LDAP: dialing %s + STARTTLS", addr)
|
||||
conn, err := ldap.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -157,27 +170,27 @@ func (a *LDAPAuth) Authenticate(username, password string) (groups []string, err
|
||||
return nil, fmt.Errorf("LDAP search bind failed: %w", err)
|
||||
}
|
||||
|
||||
dbg.Printf("LDAP: searching for user %q", username)
|
||||
a.log.Printf("LDAP: searching for user %q", username)
|
||||
userDN, err := a.findUserDN(conn, username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// Verify the password by binding as the user.
|
||||
dbg.Printf("LDAP: binding as %s", userDN)
|
||||
a.log.Printf("LDAP: binding as %s", userDN)
|
||||
if err := conn.Bind(userDN, password); err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
dbg.Printf("LDAP: bind OK for %s", userDN)
|
||||
a.log.Printf("LDAP: bind OK for %s", userDN)
|
||||
|
||||
// Re-bind as service account for group lookup — the user may lack read access.
|
||||
if err := a.searchBind(conn); err != nil {
|
||||
dbg.Printf("LDAP: re-bind for group lookup failed: %v — skipping groups", err)
|
||||
a.log.Printf("LDAP: re-bind for group lookup failed: %v — skipping groups", err)
|
||||
return nil, nil // auth succeeded; group lookup is best-effort
|
||||
}
|
||||
|
||||
groups = a.lookupGroups(conn, username, userDN)
|
||||
dbg.Printf("LDAP: groups for %s: %v", username, groups)
|
||||
a.log.Printf("LDAP: groups for %s: %v", username, groups)
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -19,12 +19,13 @@ import (
|
||||
|
||||
// CertConfig holds paths and settings for certificate and kubeconfig generation.
|
||||
type CertConfig struct {
|
||||
ServerURL string // written as the cluster.server in the returned kubeconfig
|
||||
ServerCACert string // path — embedded in the kubeconfig so kubectl can verify the API server
|
||||
ClientCACert string // path — signs the per-user client certificates
|
||||
ClientCAKey string // path
|
||||
Duration time.Duration // validity period for generated client certs
|
||||
ClusterName string // name used for the cluster/context in generated kubeconfigs
|
||||
ServerURL string // written as the cluster.server in the returned kubeconfig
|
||||
ServerCACert string // path — embedded in the kubeconfig so kubectl can verify the API server
|
||||
ClientCACert string // path — signs the per-user client certificates
|
||||
ClientCAKey string // path
|
||||
Duration time.Duration // validity period for generated client certs
|
||||
LoginDuration time.Duration // validity period for certs issued by 'ward login'
|
||||
ClusterName string // name used for the cluster/context in generated kubeconfigs
|
||||
}
|
||||
|
||||
// KubeconfigGenerator loads the k3s CAs once and issues per-user kubeconfigs on demand.
|
||||
@@ -55,6 +56,9 @@ func NewKubeconfigGenerator(cfg *CertConfig) (*KubeconfigGenerator, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Cfg returns the CertConfig (used by handler to read Duration/LoginDuration).
|
||||
func (g *KubeconfigGenerator) Cfg() *CertConfig { return g.cfg }
|
||||
|
||||
// Credential holds the raw PEM blobs and expiry for a generated client certificate.
|
||||
// Used both for kubeconfig generation and the /credential exec-plugin endpoint.
|
||||
type Credential struct {
|
||||
@@ -65,8 +69,8 @@ type Credential struct {
|
||||
|
||||
// GenerateCredential signs a fresh client certificate and returns the raw PEM data.
|
||||
// Use this when you need the cert material directly (e.g. the exec credential plugin).
|
||||
func (g *KubeconfigGenerator) GenerateCredential(username string, groups []string) (*Credential, error) {
|
||||
certPEM, keyPEM, err := g.signClientCert(username, groups)
|
||||
func (g *KubeconfigGenerator) GenerateCredential(username string, groups []string, duration time.Duration) (*Credential, error) {
|
||||
certPEM, keyPEM, err := g.signClientCert(username, groups, duration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("signing cert for %s: %w", username, err)
|
||||
}
|
||||
@@ -82,7 +86,7 @@ func (g *KubeconfigGenerator) GenerateCredential(username string, groups []strin
|
||||
// groups are embedded as the certificate's Organisation field, which Kubernetes reads
|
||||
// as RBAC group memberships.
|
||||
func (g *KubeconfigGenerator) Generate(username string, groups []string) ([]byte, error) {
|
||||
cred, err := g.GenerateCredential(username, groups)
|
||||
cred, err := g.GenerateCredential(username, groups, g.cfg.Duration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -97,7 +101,7 @@ func (g *KubeconfigGenerator) Generate(username string, groups []string) ([]byte
|
||||
}
|
||||
|
||||
// signClientCert issues an ECDSA P-256 client certificate signed by the k3s client CA.
|
||||
func (g *KubeconfigGenerator) signClientCert(username string, groups []string) (certPEM, keyPEM []byte, err error) {
|
||||
func (g *KubeconfigGenerator) signClientCert(username string, groups []string, duration time.Duration) (certPEM, keyPEM []byte, err error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generating key: %w", err)
|
||||
@@ -116,7 +120,7 @@ func (g *KubeconfigGenerator) signClientCert(username string, groups []string) (
|
||||
Organization: groups, // Kubernetes maps these to RBAC groups
|
||||
},
|
||||
NotBefore: now.Add(-5 * time.Minute), // tolerate minor clock skew
|
||||
NotAfter: now.Add(g.cfg.Duration),
|
||||
NotAfter: now.Add(duration),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
@@ -186,11 +190,11 @@ func loadCA(certFile, keyFile string) (*x509.Certificate, crypto.PrivateKey, err
|
||||
func (g *KubeconfigGenerator) GenerateBootstrap(wardURL, username string) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
err := bootstrapTmpl.Execute(&buf, map[string]string{
|
||||
"Server": g.cfg.ServerURL,
|
||||
"ServerCA": g.serverCA,
|
||||
"WardURL": wardURL,
|
||||
"Username": username,
|
||||
"Cluster": g.cfg.ClusterName,
|
||||
"Server": g.cfg.ServerURL,
|
||||
"ServerCA": g.serverCA,
|
||||
"WardURL": wardURL,
|
||||
"Username": username,
|
||||
"Cluster": g.cfg.ClusterName,
|
||||
})
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.shee.sh/james/ward/pkg/auth"
|
||||
"git.shee.sh/james/ward/pkg/cert"
|
||||
)
|
||||
|
||||
// ExecCredential is the JSON structure kubectl expects from an exec credential plugin.
|
||||
@@ -28,23 +31,24 @@ type ExecCredentialStatus struct {
|
||||
// Handler wires together authentication providers and kubeconfig generation.
|
||||
// At least one provider must be non-nil.
|
||||
type Handler struct {
|
||||
ldap *LDAPAuth
|
||||
krb *KerberosAuth
|
||||
htpasswd *HtpasswdAuth
|
||||
gen *KubeconfigGenerator
|
||||
ldap *auth.LDAPAuth
|
||||
krb *auth.KerberosAuth
|
||||
htpasswd *auth.HtpasswdAuth
|
||||
gen *cert.KubeconfigGenerator
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
// NewHandler validates that at least one auth provider is configured, then
|
||||
// loads the k3s CA files and returns a ready Handler.
|
||||
func NewHandler(ldap *LDAPAuth, krb *KerberosAuth, htpasswd *HtpasswdAuth, cfg *CertConfig) (*Handler, error) {
|
||||
if ldap == nil && krb == nil && htpasswd == nil {
|
||||
func NewHandler(ldapAuth *auth.LDAPAuth, krbAuth *auth.KerberosAuth, htpasswdAuth *auth.HtpasswdAuth, cfg *cert.CertConfig, dbg *log.Logger) (*Handler, error) {
|
||||
if ldapAuth == nil && krbAuth == nil && htpasswdAuth == nil {
|
||||
return nil, fmt.Errorf("no authentication providers configured: enable at least one of LDAP, --kerberos, or --htpasswd")
|
||||
}
|
||||
gen, err := NewKubeconfigGenerator(cfg)
|
||||
gen, err := cert.NewKubeconfigGenerator(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Handler{ldap: ldap, krb: krb, htpasswd: htpasswd, gen: gen}, nil
|
||||
return &Handler{ldap: ldapAuth, krb: krbAuth, htpasswd: htpasswdAuth, gen: gen, log: dbg}, nil
|
||||
}
|
||||
|
||||
// ServeHTTP handles GET /kubeconfig — returns a kubeconfig YAML on success.
|
||||
@@ -83,10 +87,13 @@ func (h *Handler) ServeCredential(w http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
dbg.Printf("issuing exec credential for %q groups=%v", username, groups)
|
||||
log.Printf("issuing exec credential for %q groups=%v", username, groups)
|
||||
|
||||
cred, err := h.gen.GenerateCredential(username, groups)
|
||||
duration := h.gen.Cfg().Duration
|
||||
if r.URL.Query().Get("login") == "true" {
|
||||
duration = h.gen.Cfg().LoginDuration
|
||||
}
|
||||
cred, err := h.gen.GenerateCredential(username, groups, duration)
|
||||
if err != nil {
|
||||
log.Printf("credential generation failed for %q: %v", username, err)
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
@@ -160,18 +167,18 @@ func (h *Handler) authenticate(w http.ResponseWriter, r *http.Request) (string,
|
||||
|
||||
switch {
|
||||
case h.krb != nil && strings.HasPrefix(authHeader, "Negotiate "):
|
||||
dbg.Printf("auth: trying Kerberos SPNEGO")
|
||||
h.log.Printf("auth: trying Kerberos SPNEGO")
|
||||
username, err := h.krb.Authenticate(w, r)
|
||||
if err != nil {
|
||||
log.Printf("Kerberos auth failed: %v", err)
|
||||
h.sendChallenge(w, true, h.ldap != nil || h.htpasswd != nil)
|
||||
return "", nil, false
|
||||
}
|
||||
dbg.Printf("auth: Kerberos OK, user=%q", username)
|
||||
h.log.Printf("auth: Kerberos OK, user=%q", username)
|
||||
var groups []string
|
||||
if h.ldap != nil {
|
||||
groups = h.ldap.LookupGroups(username)
|
||||
dbg.Printf("auth: LDAP group lookup for %q → %v", username, groups)
|
||||
h.log.Printf("auth: LDAP group lookup for %q → %v", username, groups)
|
||||
}
|
||||
return username, groups, true
|
||||
|
||||
@@ -181,18 +188,18 @@ func (h *Handler) authenticate(w http.ResponseWriter, r *http.Request) (string,
|
||||
h.sendChallenge(w, h.krb != nil, true)
|
||||
return "", nil, false
|
||||
}
|
||||
dbg.Printf("auth: trying Basic for user=%q", user)
|
||||
h.log.Printf("auth: trying Basic for user=%q", user)
|
||||
groups, err := h.authenticateBasic(user, password)
|
||||
if err != nil {
|
||||
log.Printf("Basic auth failed for %q: %v", user, err)
|
||||
h.sendChallenge(w, h.krb != nil, true)
|
||||
return "", nil, false
|
||||
}
|
||||
dbg.Printf("auth: Basic OK, user=%q groups=%v", user, groups)
|
||||
h.log.Printf("auth: Basic OK, user=%q groups=%v", user, groups)
|
||||
return user, groups, true
|
||||
|
||||
default:
|
||||
dbg.Printf("auth: no Authorization header, sending challenges")
|
||||
h.log.Printf("auth: no Authorization header, sending challenges")
|
||||
h.sendChallenge(w, h.krb != nil, h.ldap != nil || h.htpasswd != nil)
|
||||
return "", nil, false
|
||||
}
|
||||
@@ -206,7 +213,7 @@ func (h *Handler) authenticateBasic(username, password string) ([]string, error)
|
||||
if err == nil {
|
||||
return groups, nil
|
||||
}
|
||||
dbg.Printf("LDAP auth failed for %q: %v", username, err)
|
||||
h.log.Printf("LDAP auth failed for %q: %v", username, err)
|
||||
if h.htpasswd == nil {
|
||||
return nil, err
|
||||
}
|
||||
165
pkg/kubeconfig/kubeconfig.go
Normal file
165
pkg/kubeconfig/kubeconfig.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package kubeconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// KubeConfig is the top-level kubeconfig structure.
|
||||
type KubeConfig struct {
|
||||
APIVersion string `yaml:"apiVersion"`
|
||||
Kind string `yaml:"kind"`
|
||||
Preferences map[string]any `yaml:"preferences,omitempty"`
|
||||
Clusters []NamedCluster `yaml:"clusters"`
|
||||
Users []NamedUser `yaml:"users"`
|
||||
Contexts []NamedContext `yaml:"contexts"`
|
||||
CurrentContext string `yaml:"current-context,omitempty"`
|
||||
}
|
||||
|
||||
// NamedCluster is a named cluster entry.
|
||||
type NamedCluster struct {
|
||||
Name string `yaml:"name"`
|
||||
Cluster ClusterData `yaml:"cluster"`
|
||||
}
|
||||
|
||||
// ClusterData holds the cluster connection details.
|
||||
type ClusterData struct {
|
||||
Server string `yaml:"server"`
|
||||
CertificateAuthorityData string `yaml:"certificate-authority-data,omitempty"`
|
||||
}
|
||||
|
||||
// NamedUser is a named user entry.
|
||||
type NamedUser struct {
|
||||
Name string `yaml:"name"`
|
||||
User UserData `yaml:"user"`
|
||||
}
|
||||
|
||||
// UserData holds user credentials.
|
||||
type UserData struct {
|
||||
Exec *ExecData `yaml:"exec,omitempty"`
|
||||
}
|
||||
|
||||
// ExecData holds exec plugin configuration.
|
||||
type ExecData struct {
|
||||
APIVersion string `yaml:"apiVersion"`
|
||||
Command string `yaml:"command"`
|
||||
Args []string `yaml:"args,omitempty"`
|
||||
InteractiveMode string `yaml:"interactiveMode,omitempty"`
|
||||
}
|
||||
|
||||
// NamedContext is a named context entry.
|
||||
type NamedContext struct {
|
||||
Name string `yaml:"name"`
|
||||
Context ContextData `yaml:"context"`
|
||||
}
|
||||
|
||||
// ContextData holds context details.
|
||||
type ContextData struct {
|
||||
Cluster string `yaml:"cluster"`
|
||||
User string `yaml:"user"`
|
||||
}
|
||||
|
||||
// FilePath returns the path to the active kubeconfig file.
|
||||
// If KUBECONFIG is set to a colon-separated list, the first entry is returned.
|
||||
func FilePath() string {
|
||||
if k := os.Getenv("KUBECONFIG"); k != "" {
|
||||
if idx := strings.IndexByte(k, os.PathListSeparator); idx >= 0 {
|
||||
return k[:idx]
|
||||
}
|
||||
return k
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".kube", "config")
|
||||
}
|
||||
|
||||
// RenameContext renames the cluster and context (but not the user) in cfg.
|
||||
// The bootstrap template uses the cluster name as both the cluster and context
|
||||
// name; the user name is the actual username and is left unchanged.
|
||||
func RenameContext(cfg *KubeConfig, newName string) {
|
||||
oldName := cfg.CurrentContext
|
||||
if oldName == newName {
|
||||
return
|
||||
}
|
||||
for i := range cfg.Clusters {
|
||||
if cfg.Clusters[i].Name == oldName {
|
||||
cfg.Clusters[i].Name = newName
|
||||
}
|
||||
}
|
||||
for i := range cfg.Contexts {
|
||||
if cfg.Contexts[i].Name == oldName {
|
||||
cfg.Contexts[i].Name = newName
|
||||
cfg.Contexts[i].Context.Cluster = newName
|
||||
}
|
||||
}
|
||||
cfg.CurrentContext = newName
|
||||
}
|
||||
|
||||
// Merge merges incoming into the kubeconfig file at FilePath().
|
||||
// If setContext is true, the current-context is updated to incoming.CurrentContext.
|
||||
func Merge(incoming *KubeConfig, setContext bool) error {
|
||||
path := FilePath()
|
||||
|
||||
existing := &KubeConfig{APIVersion: "v1", Kind: "Config"}
|
||||
if data, err := os.ReadFile(path); err == nil {
|
||||
if err := yaml.Unmarshal(data, existing); err != nil {
|
||||
return fmt.Errorf("parsing existing kubeconfig: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, c := range incoming.Clusters {
|
||||
existing.Clusters = upsertCluster(existing.Clusters, c)
|
||||
}
|
||||
for _, u := range incoming.Users {
|
||||
existing.Users = upsertUser(existing.Users, u)
|
||||
}
|
||||
for _, ctx := range incoming.Contexts {
|
||||
existing.Contexts = upsertContext(existing.Contexts, ctx)
|
||||
}
|
||||
if setContext {
|
||||
existing.CurrentContext = incoming.CurrentContext
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(existing)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling kubeconfig: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return fmt.Errorf("creating kubeconfig directory: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func upsertCluster(list []NamedCluster, item NamedCluster) []NamedCluster {
|
||||
for i, c := range list {
|
||||
if c.Name == item.Name {
|
||||
list[i] = item
|
||||
return list
|
||||
}
|
||||
}
|
||||
return append(list, item)
|
||||
}
|
||||
|
||||
func upsertUser(list []NamedUser, item NamedUser) []NamedUser {
|
||||
for i, u := range list {
|
||||
if u.Name == item.Name {
|
||||
list[i] = item
|
||||
return list
|
||||
}
|
||||
}
|
||||
return append(list, item)
|
||||
}
|
||||
|
||||
func upsertContext(list []NamedContext, item NamedContext) []NamedContext {
|
||||
for i, c := range list {
|
||||
if c.Name == item.Name {
|
||||
list[i] = item
|
||||
return list
|
||||
}
|
||||
}
|
||||
return append(list, item)
|
||||
}
|
||||
Reference in New Issue
Block a user