diff --git a/apis/v1/cocoonset_types.go b/apis/v1/cocoonset_types.go index eaf9503..48709ff 100644 --- a/apis/v1/cocoonset_types.go +++ b/apis/v1/cocoonset_types.go @@ -142,9 +142,13 @@ type CocoonSetStatus struct { DesiredToolboxes int32 `json:"desiredToolboxes"` // +optional + // +listType=map + // +listMapKey=slot Agents []AgentStatus `json:"agents,omitempty"` // +optional + // +listType=map + // +listMapKey=name Toolboxes []ToolboxStatus `json:"toolboxes,omitempty"` // +optional diff --git a/apis/v1/crds/cocoonset.cocoonstack.io_cocoonsets.yaml b/apis/v1/crds/cocoonset.cocoonstack.io_cocoonsets.yaml index 544c03a..037fba6 100644 --- a/apis/v1/crds/cocoonset.cocoonstack.io_cocoonsets.yaml +++ b/apis/v1/crds/cocoonset.cocoonstack.io_cocoonsets.yaml @@ -443,6 +443,9 @@ spec: - slot type: object type: array + x-kubernetes-list-map-keys: + - slot + x-kubernetes-list-type: map conditions: items: description: Condition contains details for one aspect of the current @@ -562,6 +565,9 @@ spec: - name type: object type: array + x-kubernetes-list-map-keys: + - name + x-kubernetes-list-type: map type: object type: object served: true diff --git a/apis/v1/enums.go b/apis/v1/enums.go index 9bbfe1e..5793d5a 100644 --- a/apis/v1/enums.go +++ b/apis/v1/enums.go @@ -1,5 +1,10 @@ package v1 +import ( + "cmp" + "slices" +) + const ( AgentModeClone AgentMode = "clone" AgentModeRun AgentMode = "run" @@ -32,6 +37,17 @@ const ( BackendFirecracker Backend = "firecracker" ) +// Per-type valid-value tables. Keep ordering aligned with the const +// block above so a new enum member is one edit in each place. +var ( + agentModeValid = []AgentMode{AgentModeClone, AgentModeRun} + toolboxModeValid = []ToolboxMode{ToolboxModeRun, ToolboxModeClone, ToolboxModeStatic} + osTypeValid = []OSType{OSLinux, OSWindows, OSAndroid} + snapshotPolicyValid = []SnapshotPolicy{SnapshotPolicyAlways, SnapshotPolicyMainOnly, SnapshotPolicyNever} + connTypeValid = []ConnType{ConnTypeSSH, ConnTypeRDP, ConnTypeVNC, ConnTypeADB} + backendValid = []Backend{BackendCloudHypervisor, BackendFirecracker} +) + // AgentMode defines the mode of an agent VM. // +kubebuilder:validation:Enum=clone;run type AgentMode string @@ -66,78 +82,39 @@ type ConnType string type Backend string // IsValid reports whether m is a recognized AgentMode value. -func (m AgentMode) IsValid() bool { - return m == AgentModeClone || m == AgentModeRun -} +func (m AgentMode) IsValid() bool { return slices.Contains(agentModeValid, m) } // Default returns m when set, otherwise AgentModeClone. -func (m AgentMode) Default() AgentMode { - if m == "" { - return AgentModeClone - } - return m -} +func (m AgentMode) Default() AgentMode { return cmp.Or(m, AgentModeClone) } // IsValid reports whether m is a recognized ToolboxMode value. -func (m ToolboxMode) IsValid() bool { - return m == ToolboxModeRun || m == ToolboxModeClone || m == ToolboxModeStatic -} +func (m ToolboxMode) IsValid() bool { return slices.Contains(toolboxModeValid, m) } // Default returns m when set, otherwise ToolboxModeRun. -func (m ToolboxMode) Default() ToolboxMode { - if m == "" { - return ToolboxModeRun - } - return m -} +func (m ToolboxMode) Default() ToolboxMode { return cmp.Or(m, ToolboxModeRun) } // IsValid reports whether o is a recognized OSType value. -func (o OSType) IsValid() bool { - return o == OSLinux || o == OSWindows || o == OSAndroid -} +func (o OSType) IsValid() bool { return slices.Contains(osTypeValid, o) } // Default returns o when set, otherwise OSLinux. -func (o OSType) Default() OSType { - if o == "" { - return OSLinux - } - return o -} +func (o OSType) Default() OSType { return cmp.Or(o, OSLinux) } // IsValid reports whether p is a recognized SnapshotPolicy value. -func (p SnapshotPolicy) IsValid() bool { - return p == SnapshotPolicyAlways || p == SnapshotPolicyMainOnly || p == SnapshotPolicyNever -} +func (p SnapshotPolicy) IsValid() bool { return slices.Contains(snapshotPolicyValid, p) } // Default returns p when set, otherwise SnapshotPolicyAlways. -func (p SnapshotPolicy) Default() SnapshotPolicy { - if p == "" { - return SnapshotPolicyAlways - } - return p -} +func (p SnapshotPolicy) Default() SnapshotPolicy { return cmp.Or(p, SnapshotPolicyAlways) } // IsValid reports whether c is a recognized ConnType value. -func (c ConnType) IsValid() bool { - return c == ConnTypeSSH || c == ConnTypeRDP || c == ConnTypeVNC || c == ConnTypeADB -} +func (c ConnType) IsValid() bool { return slices.Contains(connTypeValid, c) } // Default returns c unchanged. Unlike the other enums, ConnType has no // static default: an empty value signals "infer from OS and runtime" // (see meta.ConnectionType), so this method exists only for API symmetry. -func (c ConnType) Default() ConnType { - return c -} +func (c ConnType) Default() ConnType { return c } // IsValid reports whether b is a recognized Backend value. -func (b Backend) IsValid() bool { - return b == BackendCloudHypervisor || b == BackendFirecracker -} +func (b Backend) IsValid() bool { return slices.Contains(backendValid, b) } // Default returns b when set, otherwise BackendCloudHypervisor. -func (b Backend) Default() Backend { - if b == "" { - return BackendCloudHypervisor - } - return b -} +func (b Backend) Default() Backend { return cmp.Or(b, BackendCloudHypervisor) } diff --git a/auth/session.go b/auth/session.go index 89da380..0fe4dc0 100644 --- a/auth/session.go +++ b/auth/session.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "strings" + "time" ) // Session holds the claims embedded in an HMAC-signed cookie. @@ -32,7 +33,7 @@ func SignSession(sess Session, key []byte) (string, error) { } // VerifySession validates the HMAC signature and decodes the session. -// Returns nil and false if the signature is invalid or decoding fails. +// Exp == 0 means "no expiry". func VerifySession(cookie string, key []byte) (*Session, bool) { payload, sig, ok := strings.Cut(cookie, ".") if !ok { @@ -52,5 +53,8 @@ func VerifySession(cookie string, key []byte) (*Session, bool) { if json.Unmarshal(data, sess) != nil { return nil, false } + if sess.Exp != 0 && sess.Exp <= time.Now().Unix() { + return nil, false + } return sess, true } diff --git a/auth/session_test.go b/auth/session_test.go index d809651..a2777ba 100644 --- a/auth/session_test.go +++ b/auth/session_test.go @@ -60,6 +60,34 @@ func TestVerifySessionWrongKey(t *testing.T) { } } +func TestVerifySessionRejectsExpired(t *testing.T) { + t.Parallel() + + key := []byte("test-secret-key-32-bytes-long!!!") + expired := Session{User: "dave", Exp: time.Now().Add(-time.Hour).Unix()} + cookie, err := SignSession(expired, key) + if err != nil { + t.Fatalf("SignSession: %v", err) + } + if _, ok := VerifySession(cookie, key); ok { + t.Error("expected expired cookie to fail") + } +} + +func TestVerifySessionZeroExpAccepted(t *testing.T) { + t.Parallel() + + key := []byte("test-secret-key-32-bytes-long!!!") + // Exp == 0 means "no expiry" — must remain accepted. + cookie, err := SignSession(Session{User: "eve"}, key) + if err != nil { + t.Fatalf("SignSession: %v", err) + } + if _, ok := VerifySession(cookie, key); !ok { + t.Error("expected session with zero Exp to be accepted") + } +} + func TestRandomState(t *testing.T) { t.Parallel() diff --git a/auth/state.go b/auth/state.go index ca27304..d3c52b3 100644 --- a/auth/state.go +++ b/auth/state.go @@ -3,12 +3,16 @@ package auth import ( "crypto/rand" "encoding/hex" + "fmt" ) // RandomState returns a cryptographically random 32-character hex string // suitable for OAuth state parameters and CSRF nonces. +// Panics on crypto/rand failure — a weak nonce silently breaks CSRF. func RandomState() string { b := make([]byte, 16) - _, _ = rand.Read(b) //nolint:errcheck // crypto/rand.Read never fails on supported platforms + if _, err := rand.Read(b); err != nil { + panic(fmt.Sprintf("crypto/rand.Read: %v", err)) + } return hex.EncodeToString(b) } diff --git a/k8s/netip.go b/k8s/netip.go index f03e88e..843772b 100644 --- a/k8s/netip.go +++ b/k8s/netip.go @@ -1,12 +1,22 @@ package k8s -import "net" +import ( + "errors" + "fmt" + "net" +) -// DetectNodeIP returns the first non-loopback IPv4 address, or "127.0.0.1" if none found. -func DetectNodeIP() string { +// ErrNoNodeIP is returned when no non-loopback IPv4 address is +// reachable. Callers pick the fallback — auto-substituting localhost +// would mask misconfigured network namespaces. +var ErrNoNodeIP = errors.New("no non-loopback IPv4 address found") + +// DetectNodeIP returns the first non-loopback IPv4 address, or +// ErrNoNodeIP if none exists. +func DetectNodeIP() (string, error) { addrs, err := net.InterfaceAddrs() if err != nil { - return localhost + return "", fmt.Errorf("list interface addresses: %w", err) } for _, addr := range addrs { ipNet, ok := addr.(*net.IPNet) @@ -14,8 +24,8 @@ func DetectNodeIP() string { continue } if ip4 := ipNet.IP.To4(); ip4 != nil { - return ip4.String() + return ip4.String(), nil } } - return localhost + return "", ErrNoNodeIP } diff --git a/k8s/tls.go b/k8s/tls.go index 8ca523f..86e51fa 100644 --- a/k8s/tls.go +++ b/k8s/tls.go @@ -1,6 +1,7 @@ package k8s import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -8,45 +9,92 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "errors" "fmt" + "io/fs" "math/big" "net" "os" "time" + + "github.com/projecteru2/core/log" ) const localhost = "127.0.0.1" -// LoadOrGenerateCert loads a TLS keypair from disk, falling back to a self-signed cert. -// Returns a source label for logging ("disk " or "self-signed"). -func LoadOrGenerateCert(certPath, keyPath, hostname, ip string) (tls.Certificate, string, error) { - if certPath != "" && keyPath != "" { - if _, err := os.Stat(certPath); err == nil { - cert, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return tls.Certificate{}, "", fmt.Errorf("load tls keypair %s: %w", certPath, err) - } - return cert, fmt.Sprintf("disk %s", certPath), nil - } +// LoadOrGenerateCert loads a TLS keypair from disk, falling back to a +// self-signed cert when paths are empty, the cert is missing, or the +// cert is expired. Returns a source label for logging. +func LoadOrGenerateCert(ctx context.Context, certPath, keyPath, hostname, ip string) (tls.Certificate, string, error) { + cert, source, err := tryLoadDiskCert(ctx, certPath, keyPath) + if err != nil { + return tls.Certificate{}, "", err } - cert, err := GenerateSelfSignedCert(hostname, ip) + if source != "" { + return cert, source, nil + } + cert, err = GenerateSelfSignedCert(hostname, ip) if err != nil { return tls.Certificate{}, "", fmt.Errorf("generate self-signed cert: %w", err) } return cert, "self-signed", nil } -// GenerateSelfSignedCert creates an in-memory ECDSA P-256 self-signed cert for hostname and ip. +// tryLoadDiskCert returns ("", "", nil) when the caller should fall +// back to self-signed (paths empty, cert missing, or expired) and an +// error only when a configured keypair fails to load. +func tryLoadDiskCert(ctx context.Context, certPath, keyPath string) (tls.Certificate, string, error) { + if certPath == "" || keyPath == "" { + return tls.Certificate{}, "", nil + } + if _, err := os.Stat(certPath); err != nil { + if errors.Is(err, fs.ErrNotExist) { + return tls.Certificate{}, "", nil + } + return tls.Certificate{}, "", fmt.Errorf("stat tls cert %s: %w", certPath, err) + } + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return tls.Certificate{}, "", fmt.Errorf("load tls keypair %s: %w", certPath, err) + } + if isCertExpired(ctx, cert, certPath) { + return tls.Certificate{}, "", nil + } + return cert, fmt.Sprintf("disk %s", certPath), nil +} + +// isCertExpired returns true when the leaf cert is past NotAfter. +// Parse failures are warned and treated as "not expired". +func isCertExpired(ctx context.Context, cert tls.Certificate, certPath string) bool { + logger := log.WithFunc("k8s.LoadOrGenerateCert") + if len(cert.Certificate) == 0 { + return false + } + parsed, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + logger.Warnf(ctx, "parse disk cert %s: %v (keeping cert)", certPath, err) + return false + } + if time.Now().After(parsed.NotAfter) { + logger.Warnf(ctx, "disk cert %s expired at %s, falling back to self-signed", certPath, parsed.NotAfter.Format(time.RFC3339)) + return true + } + return false +} + +// GenerateSelfSignedCert creates an in-memory ECDSA P-256 self-signed +// cert for hostname and ip. func GenerateSelfSignedCert(hostname, ip string) (tls.Certificate, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return tls.Certificate{}, err } + now := time.Now() template := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: hostname}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), + NotBefore: now, + NotAfter: now.Add(10 * 365 * 24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, DNSNames: []string{hostname, "localhost"}, diff --git a/k8s/tls_test.go b/k8s/tls_test.go index f534680..b66cf96 100644 --- a/k8s/tls_test.go +++ b/k8s/tls_test.go @@ -42,7 +42,7 @@ func TestGenerateSelfSignedCertIsParseable(t *testing.T) { } func TestLoadOrGenerateCertFallsBackToSelfSigned(t *testing.T) { - cert, source, err := LoadOrGenerateCert("/does/not/exist.crt", "/does/not/exist.key", "host", "10.0.0.1") + cert, source, err := LoadOrGenerateCert(t.Context(), "/does/not/exist.crt", "/does/not/exist.key", "host", "10.0.0.1") if err != nil { t.Fatalf("fallback: %v", err) } @@ -91,7 +91,7 @@ func TestLoadOrGenerateCertLoadsFromDisk(t *testing.T) { t.Fatalf("write key: %v", err) } - _, source, err := LoadOrGenerateCert(certPath, keyPath, "host", "10.0.0.1") + _, source, err := LoadOrGenerateCert(t.Context(), certPath, keyPath, "host", "10.0.0.1") if err != nil { t.Fatalf("load: %v", err) } @@ -100,8 +100,58 @@ func TestLoadOrGenerateCertLoadsFromDisk(t *testing.T) { } } +func TestLoadOrGenerateCertExpiredDiskCertFallsBack(t *testing.T) { + dir := t.TempDir() + certPath := filepath.Join(dir, "tls.crt") + keyPath := filepath.Join(dir, "tls.key") + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("gen key: %v", err) + } + // NotAfter in the past forces the expiry branch. + tmpl := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "host"}, + NotBefore: time.Now().Add(-2 * time.Hour), + NotAfter: time.Now().Add(-time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.ParseIP("10.0.0.1")}, + } + certDER, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("create cert: %v", err) + } + privDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + t.Fatalf("marshal key: %v", err) + } + if err := os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), 0o600); err != nil { + t.Fatalf("write cert: %v", err) + } + if err := os.WriteFile(keyPath, pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}), 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + + _, source, err := LoadOrGenerateCert(t.Context(), certPath, keyPath, "host", "10.0.0.1") + if err != nil { + t.Fatalf("load: %v", err) + } + if source != "self-signed" { + t.Errorf("expected expired disk cert to fall back to self-signed, got source = %q", source) + } +} + func TestDetectNodeIPReturnsSomething(t *testing.T) { - if got := DetectNodeIP(); got == "" { - t.Errorf("DetectNodeIP returned empty string") + // CI hosts are expected to have at least one non-loopback IPv4 + // interface; skip when none is present so the test reflects the + // environment rather than masking it as a pass. + got, err := DetectNodeIP() + if err != nil { + t.Skipf("no non-loopback IPv4 on this host: %v", err) + } + if got == "" { + t.Errorf("DetectNodeIP returned empty string with no error") } } diff --git a/log/log.go b/log/log.go index 0476576..6fe65cd 100644 --- a/log/log.go +++ b/log/log.go @@ -3,19 +3,21 @@ package log import ( "context" + "fmt" "os" corelog "github.com/projecteru2/core/log" "github.com/projecteru2/core/types" ) -// Setup initializes the core logger from envVar (default "info"). Fatals on failure. -func Setup(ctx context.Context, envVar string) { +// Setup initializes the core logger from envVar (default "info"). +func Setup(ctx context.Context, envVar string) error { level := os.Getenv(envVar) if level == "" { level = "info" } if err := corelog.SetupLog(ctx, &types.ServerLogConfig{Level: level}, ""); err != nil { - corelog.WithFunc("cocooncommon.log.Setup").Fatalf(ctx, err, "setup log: %v", err) + return fmt.Errorf("setup log: %w", err) } + return nil } diff --git a/meta/connection.go b/meta/connection.go new file mode 100644 index 0000000..0ef620a --- /dev/null +++ b/meta/connection.go @@ -0,0 +1,23 @@ +package meta + +import ( + cocoonv1 "github.com/cocoonstack/cocoon-common/apis/v1" +) + +// ConnectionType returns the connection protocol. A non-empty override +// wins over OS-based inference (e.g. Linux + xrdp → rdp). +func ConnectionType(osType string, hasVNCPort bool, override string) string { + if override != "" { + return override + } + switch { + case hasVNCPort: + return string(cocoonv1.ConnTypeVNC) + case osType == string(cocoonv1.OSAndroid): + return string(cocoonv1.ConnTypeADB) + case osType == string(cocoonv1.OSWindows): + return string(cocoonv1.ConnTypeRDP) + default: + return string(cocoonv1.ConnTypeSSH) + } +} diff --git a/meta/keys.go b/meta/keys.go new file mode 100644 index 0000000..42a4516 --- /dev/null +++ b/meta/keys.go @@ -0,0 +1,86 @@ +package meta + +const ( + // APIVersion is the apiVersion string for CocoonSet resources. + APIVersion = "cocoonset.cocoonstack.io/v1" + // KindCocoonSet is the kind string for CocoonSet resources. + KindCocoonSet = "CocoonSet" + + // TolerationKey is the virtual-kubelet provider key used to gate cocoon pods onto vk-cocoon nodes. + TolerationKey = "virtual-kubelet.io/provider" + + // LabelCocoonSet stamps a pod with its owning CocoonSet name. + LabelCocoonSet = "cocoonset.cocoonstack.io/name" + // LabelRole stamps a pod with its role (main / sub-agent / toolbox). + LabelRole = "cocoonset.cocoonstack.io/role" + // LabelSlot stamps a pod with its zero-based agent slot index. + LabelSlot = "cocoonset.cocoonstack.io/slot" + + // LabelNodePool selects which cocoon node pool a pod should land on. + LabelNodePool = "cocoonstack.io/pool" + // DefaultNodePool is the pool name used when LabelNodePool is unset. + DefaultNodePool = "default" + + // AnnotationMode declares the VM provisioning mode (clone / run / static). + AnnotationMode = "cocoonset.cocoonstack.io/mode" + // AnnotationImage carries the VM image reference. + AnnotationImage = "cocoonset.cocoonstack.io/image" + // AnnotationStorage carries the VM root volume size (resource.Quantity). + AnnotationStorage = "cocoonset.cocoonstack.io/storage" + // AnnotationManaged marks a VM as cocoon-managed ("true") versus user-managed/static. + AnnotationManaged = "cocoonset.cocoonstack.io/managed" + // AnnotationOS carries the guest OS family (linux / windows / android). + AnnotationOS = "cocoonset.cocoonstack.io/os" + // AnnotationSnapshotPolicy carries the per-pod snapshot policy. + AnnotationSnapshotPolicy = "cocoonset.cocoonstack.io/snapshot-policy" + // AnnotationNetwork carries the cluster network to attach the VM to. + AnnotationNetwork = "cocoonset.cocoonstack.io/network" + // AnnotationForcePull bypasses the image cache when set to "true". + AnnotationForcePull = "cocoonset.cocoonstack.io/force-pull" + // AnnotationCocoonSetGeneration carries the CocoonSet generation stamped at scheduling time. + AnnotationCocoonSetGeneration = "cocoonset.cocoonstack.io/generation" + + // AnnotationVMID carries the runtime VM identifier vk-cocoon assigns after creation. + AnnotationVMID = "vm.cocoonstack.io/id" + // AnnotationVMName carries the deterministic VM name the operator builds from namespace/deployment/slot. + AnnotationVMName = "vm.cocoonstack.io/name" + // AnnotationIP carries the VM's primary IPv4 address. + AnnotationIP = "vm.cocoonstack.io/ip" + // AnnotationVNCPort carries the VM's VNC port when one is exposed. + AnnotationVNCPort = "vm.cocoonstack.io/vnc-port" + // AnnotationHibernate signals "hibernate this VM" when set to "true". + AnnotationHibernate = "vm.cocoonstack.io/hibernate" + // AnnotationForkFrom names a VM to fork the new VM from. + AnnotationForkFrom = "vm.cocoonstack.io/fork-from" + // AnnotationCloneFromDir names a host directory to clone the VM image from (vk-cocoon-specific). + AnnotationCloneFromDir = "vm.cocoonstack.io/clone-from-dir" + // AnnotationConnType overrides the connection protocol inferred from OS/runtime. + AnnotationConnType = "vm.cocoonstack.io/conn-type" + // AnnotationBackend selects the hypervisor backend (cloud-hypervisor / firecracker). + AnnotationBackend = "vm.cocoonstack.io/backend" + // AnnotationNoDirectIO disables O_DIRECT on writable disks when set to "true" (cloud-hypervisor only). + AnnotationNoDirectIO = "vm.cocoonstack.io/no-direct-io" + // AnnotationProbePort overrides the default ICMP readiness probe with a TCP port check. + AnnotationProbePort = "vm.cocoonstack.io/probe-port" + // AnnotationLifecycleState carries the vk-cocoon-reported lifecycle state. + AnnotationLifecycleState = "vm.cocoonstack.io/lifecycle-state" + // AnnotationLifecycleObservedGeneration carries the CocoonSet generation observed by vk-cocoon. + AnnotationLifecycleObservedGeneration = "vm.cocoonstack.io/lifecycle-observed-generation" + // AnnotationLifecycleStateMessage carries an optional message accompanying the lifecycle state. + AnnotationLifecycleStateMessage = "vm.cocoonstack.io/lifecycle-state-message" + + // RoleMain identifies the main agent VM (slot 0). + RoleMain = "main" + // RoleSubAgent identifies a sub-agent VM (slot > 0). + RoleSubAgent = "sub-agent" + // RoleToolbox identifies a toolbox VM. + RoleToolbox = "toolbox" + + // HibernateSnapshotTag names the snapshot tag used for hibernation. + HibernateSnapshotTag = "hibernate" + // DefaultSnapshotTag names the default snapshot tag. + DefaultSnapshotTag = "latest" + + // annotationTrue is the canonical "true" string for bool-valued annotations. + annotationTrue = "true" +) diff --git a/meta/lifecycle.go b/meta/lifecycle.go index 2865568..93db7e1 100644 --- a/meta/lifecycle.go +++ b/meta/lifecycle.go @@ -6,10 +6,6 @@ import ( corev1 "k8s.io/api/core/v1" ) -// LifecycleState is the typed contract for the lifecycle-state annotation -// vk-cocoon publishes on a Pod. -type LifecycleState string - const ( LifecycleStateCreating LifecycleState = "creating" LifecycleStateReady LifecycleState = "ready" @@ -18,29 +14,31 @@ const ( LifecycleStateFailed LifecycleState = "failed" ) +var terminalStates = map[LifecycleState]struct{}{ + LifecycleStateReady: {}, + LifecycleStateHibernated: {}, + LifecycleStateFailed: {}, +} + +// LifecycleState is the typed contract for the lifecycle-state +// annotation vk-cocoon publishes on a Pod. +type LifecycleState string + // IsTerminal reports whether s is a state a client would wait for. func (s LifecycleState) IsTerminal() bool { - switch s { - case LifecycleStateReady, LifecycleStateHibernated, LifecycleStateFailed: - return true - } - return false + _, ok := terminalStates[s] + return ok } // LifecycleStatus is the full triple (state, observed-generation, message). -// Annotations is the source of truth for what gets written; Apply -// consumes the same map in-memory and Snapshot derives a comparison -// key from the same fields. type LifecycleStatus struct { State LifecycleState ObservedGeneration int64 Message string } -// Annotations returns the lifecycle annotation map for s. nil entries -// signal "delete this key" — pass to k8s.AnnotationsMergePatch to wrap -// into a `metadata.annotations` merge-patch body, or iterate directly -// to mutate an in-memory pod (see Apply). +// Annotations returns the lifecycle annotation map for a merge patch. +// Nil values signal "delete this key". func (s LifecycleStatus) Annotations() map[string]any { annos := map[string]any{ AnnotationLifecycleState: string(s.State), @@ -54,20 +52,20 @@ func (s LifecycleStatus) Annotations() map[string]any { return annos } -// Apply writes Annotations into the pod's annotations, deleting keys -// whose value is nil. Empty message clears the annotation so a stale -// failure reason cannot tail into the next lifecycle. +// Apply writes the status onto pod annotations. Empty message clears +// the annotation so a stale failure reason cannot tail into the next +// lifecycle. func (s LifecycleStatus) Apply(pod *corev1.Pod) { if pod == nil { return } a := ensurePodAnnotations(pod) - for key, val := range s.Annotations() { - if val == nil { - delete(a, key) - continue - } - a[key] = val.(string) + a[AnnotationLifecycleState] = string(s.State) + a[AnnotationLifecycleObservedGeneration] = strconv.FormatInt(s.ObservedGeneration, 10) + if s.Message == "" { + delete(a, AnnotationLifecycleStateMessage) + } else { + a[AnnotationLifecycleStateMessage] = s.Message } } @@ -97,15 +95,14 @@ func ReadLifecycleState(pod *corev1.Pod) LifecycleState { return LifecycleState(pod.Annotations[AnnotationLifecycleState]) } -// ReadLifecycleObservedGeneration reads the observed-generation annotation. -// Missing or unparseable returns 0 — callers treat it as "not observed yet". +// ReadLifecycleObservedGeneration reads the observed-generation +// annotation; missing or unparseable returns 0. func ReadLifecycleObservedGeneration(pod *corev1.Pod) int64 { return readInt64Annotation(pod, AnnotationLifecycleObservedGeneration) } -// ReadCocoonSetGeneration reads the CocoonSet generation stamped by -// cocoon-operator. vk-cocoon writes it back as observed-generation — -// counter-based completion is not subject to wallclock skew. +// ReadCocoonSetGeneration reads the CocoonSet generation stamped on the +// pod by cocoon-operator. func ReadCocoonSetGeneration(pod *corev1.Pod) int64 { return readInt64Annotation(pod, AnnotationCocoonSetGeneration) } @@ -119,8 +116,6 @@ func StampCocoonSetGeneration(pod *corev1.Pod, generation int64) { a[AnnotationCocoonSetGeneration] = strconv.FormatInt(generation, 10) } -// readInt64Annotation parses an int64-valued annotation, returning 0 -// when missing or unparseable. func readInt64Annotation(pod *corev1.Pod, key string) int64 { if pod == nil { return 0 diff --git a/meta/meta.go b/meta/meta.go index 8d31dd6..b7465e8 100644 --- a/meta/meta.go +++ b/meta/meta.go @@ -1,157 +1,3 @@ -// Package meta defines shared metadata keys and naming rules used across Cocoon components. +// Package meta defines shared metadata keys and naming rules used +// across Cocoon components. package meta - -import ( - "slices" - "strconv" - "strings" - - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - cocoonv1 "github.com/cocoonstack/cocoon-common/apis/v1" -) - -const ( - APIVersion = "cocoonset.cocoonstack.io/v1" - KindCocoonSet = "CocoonSet" - - TolerationKey = "virtual-kubelet.io/provider" - - LabelCocoonSet = "cocoonset.cocoonstack.io/name" - LabelRole = "cocoonset.cocoonstack.io/role" - LabelSlot = "cocoonset.cocoonstack.io/slot" - - LabelNodePool = "cocoonstack.io/pool" - DefaultNodePool = "default" - LabelManagedBy = "app.kubernetes.io/managed-by" - - AnnotationMode = "cocoonset.cocoonstack.io/mode" - AnnotationImage = "cocoonset.cocoonstack.io/image" - AnnotationStorage = "cocoonset.cocoonstack.io/storage" - AnnotationManaged = "cocoonset.cocoonstack.io/managed" - AnnotationOS = "cocoonset.cocoonstack.io/os" - AnnotationSnapshotPolicy = "cocoonset.cocoonstack.io/snapshot-policy" - AnnotationNetwork = "cocoonset.cocoonstack.io/network" - AnnotationForcePull = "cocoonset.cocoonstack.io/force-pull" - AnnotationCocoonSetGeneration = "cocoonset.cocoonstack.io/generation" - - AnnotationVMID = "vm.cocoonstack.io/id" - AnnotationVMName = "vm.cocoonstack.io/name" - AnnotationIP = "vm.cocoonstack.io/ip" - AnnotationVNCPort = "vm.cocoonstack.io/vnc-port" - AnnotationHibernate = "vm.cocoonstack.io/hibernate" - AnnotationForkFrom = "vm.cocoonstack.io/fork-from" - AnnotationCloneFromDir = "vm.cocoonstack.io/clone-from-dir" - AnnotationConnType = "vm.cocoonstack.io/conn-type" - AnnotationBackend = "vm.cocoonstack.io/backend" - AnnotationNoDirectIO = "vm.cocoonstack.io/no-direct-io" - AnnotationProbePort = "vm.cocoonstack.io/probe-port" - AnnotationLifecycleState = "vm.cocoonstack.io/lifecycle-state" - AnnotationLifecycleObservedGeneration = "vm.cocoonstack.io/lifecycle-observed-generation" - AnnotationLifecycleStateMessage = "vm.cocoonstack.io/lifecycle-state-message" - - RoleMain = "main" - RoleSubAgent = "sub-agent" - RoleToolbox = "toolbox" - - HibernateSnapshotTag = "hibernate" - DefaultSnapshotTag = "latest" - - annotationTrue = "true" -) - -// HasCocoonToleration reports whether the toleration list includes the virtual-kubelet provider key. -func HasCocoonToleration(tolerations []corev1.Toleration) bool { - return slices.ContainsFunc(tolerations, func(t corev1.Toleration) bool { - return t.Key == TolerationKey - }) -} - -// IsOwnedByCocoonSet reports whether any owner reference is a CocoonSet. -func IsOwnedByCocoonSet(ownerRefs []metav1.OwnerReference) bool { - return slices.ContainsFunc(ownerRefs, func(ref metav1.OwnerReference) bool { - return ref.Kind == KindCocoonSet - }) -} - -// OwnerDeploymentName extracts the deployment name from a ReplicaSet owner reference. -func OwnerDeploymentName(ownerRefs []metav1.OwnerReference) string { - for _, ref := range ownerRefs { - if ref.Kind != "ReplicaSet" { - continue - } - if before, _, ok := lastCut(ref.Name, "-"); ok { - return before - } - } - return "" -} - -// VMNameForDeployment builds a deterministic VM name from a deployment and slot index. -func VMNameForDeployment(namespace, deployment string, slot int) string { - return "vk-" + namespace + "-" + deployment + "-" + strconv.Itoa(slot) -} - -// VMNameForPod builds a deterministic VM name from a pod name. -func VMNameForPod(namespace, podName string) string { - return "vk-" + namespace + "-" + podName -} - -// ExtractSlotFromVMName parses the trailing slot index from a VM name, or -1 if absent. -func ExtractSlotFromVMName(vmName string) int { - _, after, ok := lastCut(vmName, "-") - if !ok { - return -1 - } - n, err := strconv.Atoi(after) - if err != nil { - return -1 - } - return n -} - -// MainAgentVMName replaces the slot suffix with 0. Non-slot names are returned unchanged. -func MainAgentVMName(vmName string) string { - if ExtractSlotFromVMName(vmName) < 0 { - return vmName - } - before, _, _ := lastCut(vmName, "-") - return before + "-0" -} - -// InferRoleFromVMName returns RoleMain for slot 0, RoleSubAgent otherwise. -func InferRoleFromVMName(vmName string) string { - if ExtractSlotFromVMName(vmName) == 0 { - return RoleMain - } - return RoleSubAgent -} - -// ConnectionType returns the connection protocol. A non-empty override -// (typically AnnotationConnType) wins over OS-based inference, so a Linux -// image running xrdp can advertise rdp without faking its OS field. -func ConnectionType(osType string, hasVNCPort bool, override string) string { - if override != "" { - return override - } - switch { - case hasVNCPort: - return string(cocoonv1.ConnTypeVNC) - case osType == "android": - return string(cocoonv1.ConnTypeADB) - case osType == "windows": - return string(cocoonv1.ConnTypeRDP) - default: - return string(cocoonv1.ConnTypeSSH) - } -} - -// lastCut is like strings.Cut but splits at the last occurrence of sep. -func lastCut(s, sep string) (before, after string, found bool) { - idx := strings.LastIndex(s, sep) - if idx < 0 { - return s, "", false - } - return s[:idx], s[idx+len(sep):], true -} diff --git a/meta/meta_test.go b/meta/meta_test.go index c093a0f..87c5471 100644 --- a/meta/meta_test.go +++ b/meta/meta_test.go @@ -20,14 +20,6 @@ func TestVMNamingHelpers(t *testing.T) { if got := ExtractSlotFromVMName("vk-prod-toolbox"); got != -1 { t.Fatalf("expected non-slot vm name to return -1, got %d", got) } - if got := MainAgentVMName("vk-prod-demo-2"); got != "vk-prod-demo-0" { - t.Fatalf("main agent name mismatch: got %q", got) - } - // A pod-style name (no slot suffix) must be returned unchanged — - // the trailing dash inside the name is not a slot separator. - if got := MainAgentVMName("vk-prod-toolbox"); got != "vk-prod-toolbox" { - t.Fatalf("MainAgentVMName must not coerce non-slot names, got %q", got) - } } func TestInferRoleFromVMName(t *testing.T) { @@ -39,6 +31,95 @@ func TestInferRoleFromVMName(t *testing.T) { } } +func TestExtractAgentSlot(t *testing.T) { + cases := []struct { + name string + ns string + cocoonSet string + vmName string + want int + }{ + { + name: "main agent", + ns: "prod", + cocoonSet: "demo", + vmName: "vk-prod-demo-0", + want: 0, + }, + { + name: "sub-agent", + ns: "prod", + cocoonSet: "demo", + vmName: "vk-prod-demo-3", + want: 3, + }, + { + // The legacy ExtractSlotFromVMName would mis-read this as + // slot 2 because it splits at the last dash. ExtractAgentSlot + // rejects it because the suffix after the agent prefix + // contains a dash. + name: "toolbox with trailing digit is not an agent slot", + ns: "prod", + cocoonSet: "demo", + vmName: "vk-prod-demo-db-2", + want: -1, + }, + { + name: "toolbox without trailing digit", + ns: "prod", + cocoonSet: "demo", + vmName: "vk-prod-demo-toolbox", + want: -1, + }, + { + name: "different cocoonset", + ns: "prod", + cocoonSet: "demo", + vmName: "vk-prod-other-0", + want: -1, + }, + { + name: "different namespace", + ns: "prod", + cocoonSet: "demo", + vmName: "vk-staging-demo-0", + want: -1, + }, + { + name: "non-vk prefix", + ns: "prod", + cocoonSet: "demo", + vmName: "prod-demo-0", + want: -1, + }, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + if got := ExtractAgentSlot(tt.ns, tt.cocoonSet, tt.vmName); got != tt.want { + t.Errorf("ExtractAgentSlot(%q,%q,%q) = %d, want %d", tt.ns, tt.cocoonSet, tt.vmName, got, tt.want) + } + }) + } +} + +func TestMainAgentVMNameFor(t *testing.T) { + if got := MainAgentVMNameFor("prod", "demo"); got != "vk-prod-demo-0" { + t.Errorf("got %q, want %q", got, "vk-prod-demo-0") + } +} + +func TestInferRoleFromAgentSlot(t *testing.T) { + if got := InferRoleFromAgentSlot(0); got != RoleMain { + t.Errorf("slot 0 = %q, want %q", got, RoleMain) + } + if got := InferRoleFromAgentSlot(7); got != RoleSubAgent { + t.Errorf("slot 7 = %q, want %q", got, RoleSubAgent) + } + if got := InferRoleFromAgentSlot(-1); got != RoleToolbox { + t.Errorf("slot -1 = %q, want %q", got, RoleToolbox) + } +} + func TestConnectionType(t *testing.T) { cases := []struct { name string @@ -64,17 +145,56 @@ func TestConnectionType(t *testing.T) { } func TestOwnerDeploymentName(t *testing.T) { - ownerRefs := []metav1.OwnerReference{ - {Kind: "ReplicaSet", Name: "demo-7b7c9d9d5f"}, + cases := []struct { + name string + owners []metav1.OwnerReference + want string + wantOK bool + }{ + { + name: "replicaset with hash suffix", + owners: []metav1.OwnerReference{{Kind: "ReplicaSet", Name: "demo-7b7c9d9d5f"}}, + want: "demo", + wantOK: true, + }, + { + name: "no owners", + owners: nil, + wantOK: false, + }, + { + name: "non-replicaset owner", + owners: []metav1.OwnerReference{{Kind: "Deployment", Name: "demo"}}, + wantOK: false, + }, + { + name: "replicaset with no hash suffix", + owners: []metav1.OwnerReference{{Kind: "ReplicaSet", Name: "demo"}}, + wantOK: false, + }, } - if got := OwnerDeploymentName(ownerRefs); got != "demo" { - t.Fatalf("deployment name mismatch: got %q", got) + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, ok := OwnerDeploymentName(tt.owners) + if ok != tt.wantOK { + t.Fatalf("ok = %v, want %v", ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("name = %q, want %q", got, tt.want) + } + }) } } -func TestHasCocoonToleration(t *testing.T) { +func TestHasCocoonTolerationKey(t *testing.T) { tolerations := []corev1.Toleration{{Key: TolerationKey}} - if !HasCocoonToleration(tolerations) { + if !HasCocoonTolerationKey(tolerations) { t.Fatalf("expected toleration to be detected") } + if HasCocoonTolerationKey(nil) { + t.Errorf("expected nil tolerations to be rejected") + } + if HasCocoonTolerationKey([]corev1.Toleration{{Key: "other"}}) { + t.Errorf("expected unrelated toleration to be rejected") + } } diff --git a/meta/owner.go b/meta/owner.go new file mode 100644 index 0000000..9d0138a --- /dev/null +++ b/meta/owner.go @@ -0,0 +1,38 @@ +package meta + +import ( + "slices" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// HasCocoonTolerationKey reports whether tolerations include an entry +// whose Key matches TolerationKey. Operator/Value/Effect are ignored — +// the cocoon-webhook gate is intentionally permissive. +func HasCocoonTolerationKey(tolerations []corev1.Toleration) bool { + return slices.ContainsFunc(tolerations, func(t corev1.Toleration) bool { + return t.Key == TolerationKey + }) +} + +// IsOwnedByCocoonSet reports whether any owner reference is a CocoonSet. +func IsOwnedByCocoonSet(ownerRefs []metav1.OwnerReference) bool { + return slices.ContainsFunc(ownerRefs, func(ref metav1.OwnerReference) bool { + return ref.Kind == KindCocoonSet + }) +} + +// OwnerDeploymentName extracts the deployment name from a ReplicaSet +// owner reference. Returns ok=false when absent or unparseable. +func OwnerDeploymentName(ownerRefs []metav1.OwnerReference) (string, bool) { + for _, ref := range ownerRefs { + if ref.Kind != "ReplicaSet" { + continue + } + if before, _, ok := lastCut(ref.Name, "-"); ok { + return before, true + } + } + return "", false +} diff --git a/meta/vmname.go b/meta/vmname.go new file mode 100644 index 0000000..0ca58ad --- /dev/null +++ b/meta/vmname.go @@ -0,0 +1,94 @@ +package meta + +import ( + "strconv" + "strings" +) + +// VMNameForDeployment builds a deterministic VM name from a deployment and slot index. +func VMNameForDeployment(namespace, deployment string, slot int) string { + return "vk-" + namespace + "-" + deployment + "-" + strconv.Itoa(slot) +} + +// VMNameForPod builds a deterministic VM name from a pod name. +func VMNameForPod(namespace, podName string) string { + return "vk-" + namespace + "-" + podName +} + +// AgentVMNamePrefix returns "vk-NAMESPACE-COCOONSET-", the prefix every +// agent VM name shares. +func AgentVMNamePrefix(namespace, cocoonSet string) string { + return "vk-" + namespace + "-" + cocoonSet + "-" +} + +// ExtractAgentSlot parses the trailing slot index from vmName when it +// matches the agent naming convention for (namespace, cocoonSet), or +// -1 for any toolbox VM name (e.g. "vk-NS-CS-db-2"). +func ExtractAgentSlot(namespace, cocoonSet, vmName string) int { + prefix := AgentVMNamePrefix(namespace, cocoonSet) + suffix, ok := strings.CutPrefix(vmName, prefix) + if !ok || strings.Contains(suffix, "-") { + return -1 + } + n, err := strconv.Atoi(suffix) + if err != nil || n < 0 { + return -1 + } + return n +} + +// MainAgentVMNameFor returns the VM name of the main (slot 0) agent +// for (namespace, cocoonSet). +func MainAgentVMNameFor(namespace, cocoonSet string) string { + return VMNameForDeployment(namespace, cocoonSet, 0) +} + +// InferRoleFromAgentSlot returns RoleMain for slot 0, RoleSubAgent for +// positive slots, RoleToolbox for slot < 0. +func InferRoleFromAgentSlot(slot int) string { + switch { + case slot < 0: + return RoleToolbox + case slot == 0: + return RoleMain + default: + return RoleSubAgent + } +} + +// ExtractSlotFromVMName parses the trailing slot index from a VM name, +// or -1 if absent. +// +// Deprecated: misclassifies toolbox names with numeric suffixes (e.g. +// "vk-NS-CS-db-2" → slot 2). Prefer ExtractAgentSlot. +func ExtractSlotFromVMName(vmName string) int { + _, after, ok := lastCut(vmName, "-") + if !ok { + return -1 + } + n, err := strconv.Atoi(after) + if err != nil { + return -1 + } + return n +} + +// InferRoleFromVMName returns RoleMain for slot 0, RoleSubAgent otherwise. +// +// Deprecated: shares the toolbox-collision bug of ExtractSlotFromVMName. +// Prefer InferRoleFromAgentSlot(ExtractAgentSlot(ns, cs, vmName)). +func InferRoleFromVMName(vmName string) string { + if ExtractSlotFromVMName(vmName) == 0 { + return RoleMain + } + return RoleSubAgent +} + +// lastCut is like strings.Cut but splits at the last occurrence of sep. +func lastCut(s, sep string) (before, after string, found bool) { + idx := strings.LastIndex(s, sep) + if idx < 0 { + return s, "", false + } + return s[:idx], s[idx+len(sep):], true +}