diff --git a/lib/forkvm/copy.go b/lib/forkvm/copy.go index fda4a48f..8ee9c5ff 100644 --- a/lib/forkvm/copy.go +++ b/lib/forkvm/copy.go @@ -33,12 +33,30 @@ type copyState struct { reflinkDead bool } +// CopyOptions tunes CopyGuestDirectory behavior. The zero value reproduces +// the original full-copy semantics; callers can opt into skipping specific +// paths when the consumer arranges its own substitute (e.g. a symlink to a +// template-shared mem-file). +type CopyOptions struct { + // SkipRelPaths lists relative paths under srcDir that should not be + // materialized in dstDir. Comparison is exact and uses forward-slash + // separators on all platforms. + SkipRelPaths []string +} + // CopyGuestDirectory recursively copies a guest directory to a new destination. // Regular files are cloned via reflink (FICLONE) when the underlying filesystem // supports it; otherwise we fall back to a sparse extent copy // (SEEK_DATA/SEEK_HOLE). Runtime sockets and logs are skipped because they are // host-runtime artifacts. func CopyGuestDirectory(srcDir, dstDir string) error { + return CopyGuestDirectoryWithOptions(srcDir, dstDir, CopyOptions{}) +} + +// CopyGuestDirectoryWithOptions is the option-taking variant of +// CopyGuestDirectory. Use this when forking with template-shared assets, so +// the caller can install a symlink in place of a heavy copied file. 
+func CopyGuestDirectoryWithOptions(srcDir, dstDir string, opts CopyOptions) error { srcInfo, err := os.Stat(srcDir) if err != nil { return fmt.Errorf("stat source directory: %w", err) @@ -56,6 +74,11 @@ func CopyGuestDirectory(srcDir, dstDir string) error { state.reflinkDead = true } + skipSet := make(map[string]struct{}, len(opts.SkipRelPaths)) + for _, p := range opts.SkipRelPaths { + skipSet[filepath.ToSlash(p)] = struct{}{} + } + return filepath.WalkDir(srcDir, func(path string, d fs.DirEntry, walkErr error) error { if walkErr != nil { return walkErr @@ -68,6 +91,12 @@ func CopyGuestDirectory(srcDir, dstDir string) error { if relPath == "." { return nil } + if _, skip := skipSet[filepath.ToSlash(relPath)]; skip { + if d.IsDir() { + return filepath.SkipDir + } + return nil + } if d.IsDir() && shouldSkipDirectory(relPath) { return filepath.SkipDir } diff --git a/lib/forkvm/copy_test.go b/lib/forkvm/copy_test.go index c71f6c4e..56fb6caf 100644 --- a/lib/forkvm/copy_test.go +++ b/lib/forkvm/copy_test.go @@ -44,6 +44,25 @@ func TestCopyGuestDirectory(t *testing.T) { assert.Equal(t, "metadata.json", linkTarget) } +func TestCopyGuestDirectory_SkipRelPaths(t *testing.T) { + src := filepath.Join(t.TempDir(), "src") + dst := filepath.Join(t.TempDir(), "dst") + + require.NoError(t, os.MkdirAll(filepath.Join(src, "snapshots", "snapshot-latest"), 0755)) + require.NoError(t, os.WriteFile(filepath.Join(src, "snapshots", "snapshot-latest", "config.json"), []byte(`{}`), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(src, "snapshots", "snapshot-latest", "memory"), []byte("the heavy mem-file"), 0644)) + require.NoError(t, os.WriteFile(filepath.Join(src, "snapshots", "snapshot-latest", "state"), []byte("device state"), 0644)) + + err := CopyGuestDirectoryWithOptions(src, dst, CopyOptions{ + SkipRelPaths: []string{"snapshots/snapshot-latest/memory"}, + }) + require.NoError(t, err) + + assert.NoFileExists(t, filepath.Join(dst, "snapshots", "snapshot-latest", "memory")) + 
assert.FileExists(t, filepath.Join(dst, "snapshots", "snapshot-latest", "config.json")) + assert.FileExists(t, filepath.Join(dst, "snapshots", "snapshot-latest", "state")) +} + func TestCopyGuestDirectory_DoesNotSkipTmpSuffixedDirectories(t *testing.T) { src := filepath.Join(t.TempDir(), "src") dst := filepath.Join(t.TempDir(), "dst") diff --git a/lib/instances/fork.go b/lib/instances/fork.go index 1bb4102f..e358048f 100644 --- a/lib/instances/fork.go +++ b/lib/instances/fork.go @@ -255,19 +255,39 @@ func (m *manager) forkInstanceFromStoppedOrStandby(ctx context.Context, id strin fromSnapshot := source.State == StateStandby || source.State == StateTemplate + // shareMemFile gates mem-file fan-out from the source's standby snapshot. + // Firecracker only: it mmaps the snapshot mem-file MAP_PRIVATE on restore, + // so all forks safely COW from the same backing file. Cloud-hypervisor and + // other hypervisors take a copy-mode path and don't benefit. Restricted to + // Template sources because they are explicitly promoted as fork-only and + // can never be restored — sharing the mem-file with a non-Template source + // would let a later RestoreInstance mutate the file out from under live + // forks. 
+ shareMemFile := source.State == StateTemplate && stored.HypervisorType == hypervisor.TypeFirecracker + if fromSnapshot { if err := m.ensureSnapshotMemoryReady(ctx, m.paths.InstanceSnapshotLatest(id), m.snapshotJobKeyForInstance(id), stored.HypervisorType); err != nil { return nil, fmt.Errorf("prepare standby snapshot for fork: %w", err) } } - if err := forkvm.CopyGuestDirectory(srcDir, dstDir); err != nil { + copyOpts := forkvm.CopyOptions{} + if shareMemFile { + copyOpts.SkipRelPaths = []string{templateSharedMemFileRelPath} + } + if err := forkvm.CopyGuestDirectoryWithOptions(srcDir, dstDir, copyOpts); err != nil { if errors.Is(err, forkvm.ErrSparseCopyUnsupported) { return nil, fmt.Errorf("fork requires sparse-capable filesystem (SEEK_DATA/SEEK_HOLE unsupported): %w", err) } return nil, fmt.Errorf("clone guest directory: %w", err) } + if shareMemFile { + if err := m.installForkSharedMemFile(dstDir, id); err != nil { + return nil, fmt.Errorf("install shared mem-file: %w", err) + } + } + starter, err := m.getVMStarter(stored.HypervisorType) if err != nil { return nil, fmt.Errorf("get vm starter: %w", err) diff --git a/lib/instances/templates.go b/lib/instances/templates.go new file mode 100644 index 00000000..b40e5859 --- /dev/null +++ b/lib/instances/templates.go @@ -0,0 +1,46 @@ +package instances + +import ( + "fmt" + "os" + "path/filepath" +) + +const ( + templateSharedMemFileName = "memory" + templateSharedMemFileRelPath = "snapshots/snapshot-latest/memory" +) + +// installForkSharedMemFile arranges the fork's snapshot directory so the +// guest mem-file is a hardlink to the source template instance's snapshot +// mem-file instead of a per-fork copy. firecracker mmaps the mem-file +// MAP_PRIVATE during restore, so all forks COW from the same backing inode. +// +// Layout: forkDataDir is the fork's data dir. The snapshot dir is at +// <forkDataDir>/snapshots/snapshot-latest, and the mem-file lives at +// <snapshotDir>/memory.
+// The hardlink shares the inode with the source
+// instance's standby snapshot mem-file.
+//
+// We use a hardlink rather than a symlink because firecracker's restore
+// path temporarily aliases the source data dir to the fork data dir while
+// it loads the snapshot (see withSnapshotSourceDirAlias). A symlink whose
+// target traverses the source dir would resolve back into the fork dir
+// during that window and trip ELOOP; a hardlink resolves by inode so the
+// alias has no effect on it. Hardlinks require both paths on the same
+// filesystem, which holds for our standard data-dir layout.
+//
+// Any entry already present at the fork's mem-file path is removed before
+// linking. A removal failure other than "does not exist" is returned to the
+// caller instead of being ignored, so the real cause is reported rather
+// than a follow-on "file exists" error from the link call.
+//
+// Returns a wrapped error if the source mem-file cannot be stat'ed, the
+// fork snapshot directory cannot be created, a stale destination cannot be
+// removed, or the hardlink cannot be created.
+func (m *manager) installForkSharedMemFile(forkDataDir, sourceInstanceID string) error {
+	srcMem := filepath.Join(m.paths.InstanceSnapshotLatest(sourceInstanceID), templateSharedMemFileName)
+	// Fail early with a clear message when the template has no mem-file.
+	if _, err := os.Stat(srcMem); err != nil {
+		return fmt.Errorf("stat template mem-file: %w", err)
+	}
+	dstSnapshotDir := filepath.Join(forkDataDir, "snapshots", "snapshot-latest")
+	if err := os.MkdirAll(dstSnapshotDir, 0o755); err != nil {
+		return fmt.Errorf("ensure fork snapshot dir: %w", err)
+	}
+	dstMem := filepath.Join(dstSnapshotDir, templateSharedMemFileName)
+	// os.Link refuses to overwrite, so clear any leftover entry first. A
+	// missing entry is the normal case and is not an error.
+	if err := os.Remove(dstMem); err != nil && !os.IsNotExist(err) {
+		return fmt.Errorf("remove stale fork mem-file: %w", err)
+	}
+	if err := os.Link(srcMem, dstMem); err != nil {
+		return fmt.Errorf("hardlink shared mem-file: %w", err)
+	}
+	return nil
+}
diff --git a/lib/instances/templates_shared_memfile_linux_test.go b/lib/instances/templates_shared_memfile_linux_test.go
new file mode 100644
index 00000000..49ccf98e
--- /dev/null
+++ b/lib/instances/templates_shared_memfile_linux_test.go
@@ -0,0 +1,84 @@
+//go:build linux
+
+package instances
+
+import (
+	"context"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/kernel/hypeman/lib/hypervisor"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestForkFirecrackerSharesMemFile_FromTemplate verifies the end-to-end fork
+// path: when the source is a Firecracker Template instance, the fork's
+// mem-file is a hardlink to the
source's mem-file instead of a copy. This +// preserves the firecracker MAP_PRIVATE COW semantics that let multiple forks +// share the heavy backing file. +func TestForkFirecrackerSharesMemFile_FromTemplate(t *testing.T) { + t.Parallel() + + mgr, _ := setupTestManager(t) + ctx := context.Background() + + sourceID := "shared-memfile-fc-src" + createStandbySnapshotSourceFixture(t, mgr, sourceID, "shared-memfile-fc-src", hypervisor.TypeFirecracker) + promoteFixtureToTemplate(t, mgr, sourceID) + + srcSnapshotDir := mgr.paths.InstanceSnapshotLatest(sourceID) + srcMem := filepath.Join(srcSnapshotDir, templateSharedMemFileName) + require.NoError(t, os.WriteFile(srcMem, []byte("firecracker mem-file contents"), 0o644)) + snapshotConfigPath := mgr.paths.InstanceSnapshotConfig(sourceID) + require.NoError(t, os.MkdirAll(filepath.Dir(snapshotConfigPath), 0o755)) + require.NoError(t, os.WriteFile(snapshotConfigPath, []byte(`{}`), 0o644)) + + forked, err := mgr.forkInstanceFromStoppedOrStandby(ctx, sourceID, ForkInstanceRequest{ + Name: "shared-memfile-fc-fork", + TargetState: StateStopped, + }, true) + require.NoError(t, err) + require.NotNil(t, forked) + + forkMem := filepath.Join(mgr.paths.InstanceSnapshotLatest(forked.Id), templateSharedMemFileName) + info, err := os.Lstat(forkMem) + require.NoError(t, err) + assert.True(t, info.Mode().IsRegular(), "fork mem-file must be a regular file (hardlink) for firecracker fan-out") + assert.True(t, sameInode(t, srcMem, forkMem), "fork mem-file must share the source's inode") +} + +// TestForkFirecrackerStandbySourceDoesNotShareMemFile guards the +// non-Template carve-out: forking a plain Standby source must copy the +// mem-file outright. Sharing would let a later RestoreInstance on the source +// mutate the file out from under live forks. 
+func TestForkFirecrackerStandbySourceDoesNotShareMemFile(t *testing.T) { + t.Parallel() + + mgr, _ := setupTestManager(t) + ctx := context.Background() + + sourceID := "standby-fork-fc-src" + createStandbySnapshotSourceFixture(t, mgr, sourceID, "standby-fork-fc-src", hypervisor.TypeFirecracker) + + srcSnapshotDir := mgr.paths.InstanceSnapshotLatest(sourceID) + srcMem := filepath.Join(srcSnapshotDir, templateSharedMemFileName) + require.NoError(t, os.WriteFile(srcMem, []byte("firecracker mem-file contents"), 0o644)) + snapshotConfigPath := mgr.paths.InstanceSnapshotConfig(sourceID) + require.NoError(t, os.MkdirAll(filepath.Dir(snapshotConfigPath), 0o755)) + require.NoError(t, os.WriteFile(snapshotConfigPath, []byte(`{}`), 0o644)) + + forked, err := mgr.forkInstanceFromStoppedOrStandby(ctx, sourceID, ForkInstanceRequest{ + Name: "standby-fork-fc-fork", + TargetState: StateStopped, + }, true) + require.NoError(t, err) + require.NotNil(t, forked) + + forkMem := filepath.Join(mgr.paths.InstanceSnapshotLatest(forked.Id), templateSharedMemFileName) + info, err := os.Lstat(forkMem) + require.NoError(t, err) + require.True(t, info.Mode().IsRegular(), "standby-source fork mem-file must be a regular file copy") + assert.False(t, sameInode(t, srcMem, forkMem), "standby-source fork mem-file must be a copy, not a hardlink to source") +} diff --git a/lib/instances/templates_shared_memfile_test.go b/lib/instances/templates_shared_memfile_test.go new file mode 100644 index 00000000..f221a0d3 --- /dev/null +++ b/lib/instances/templates_shared_memfile_test.go @@ -0,0 +1,71 @@ +package instances + +import ( + "os" + "path/filepath" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func sameInode(t *testing.T, a, b string) bool { + t.Helper() + ai, err := os.Stat(a) + require.NoError(t, err) + bi, err := os.Stat(b) + require.NoError(t, err) + as := ai.Sys().(*syscall.Stat_t) + bs := bi.Sys().(*syscall.Stat_t) + return as.Ino 
== bs.Ino && as.Dev == bs.Dev +} + +// TestInstallForkSharedMemFile_HardlinksSourceMemFile verifies that the helper +// creates a hardlink at the fork's snapshot mem-file path that shares the +// source instance's mem-file inode. +func TestInstallForkSharedMemFile_HardlinksSourceMemFile(t *testing.T) { + t.Parallel() + + mgr, _ := newStorageOnlyManager(t) + sourceID := "shared-memfile-source" + + srcSnapshotDir := mgr.paths.InstanceSnapshotLatest(sourceID) + require.NoError(t, os.MkdirAll(srcSnapshotDir, 0o755)) + srcMem := filepath.Join(srcSnapshotDir, templateSharedMemFileName) + require.NoError(t, os.WriteFile(srcMem, []byte("guest memory bytes"), 0o644)) + + forkDir := filepath.Join(t.TempDir(), "fork-data") + + require.NoError(t, mgr.installForkSharedMemFile(forkDir, sourceID)) + + forkMem := filepath.Join(forkDir, "snapshots", "snapshot-latest", templateSharedMemFileName) + info, err := os.Lstat(forkMem) + require.NoError(t, err) + assert.True(t, info.Mode().IsRegular(), "fork mem-file must be a regular file (hardlink), not a symlink") + assert.True(t, sameInode(t, srcMem, forkMem), "fork mem-file must share the source's inode") +} + +// TestInstallForkSharedMemFile_ErrorsWhenSourceMissing makes sure the helper +// refuses to silently create a dangling link when the source mem-file does not +// exist. +func TestInstallForkSharedMemFile_ErrorsWhenSourceMissing(t *testing.T) { + t.Parallel() + + mgr, _ := newStorageOnlyManager(t) + forkDir := filepath.Join(t.TempDir(), "fork-data") + + err := mgr.installForkSharedMemFile(forkDir, "no-such-source") + require.Error(t, err) +} + +// promoteFixtureToTemplate marks the source's stored metadata as a Template +// without invoking the full PromoteToTemplate lifecycle (which would require +// a live VM). Test-only shortcut. 
+func promoteFixtureToTemplate(t *testing.T, mgr *manager, id string) { + t.Helper() + meta, err := mgr.loadMetadata(id) + require.NoError(t, err) + meta.IsTemplate = true + require.NoError(t, mgr.saveMetadata(meta)) +} diff --git a/lib/uffd/server_linux.go b/lib/uffd/server_linux.go new file mode 100644 index 00000000..616ad3ab --- /dev/null +++ b/lib/uffd/server_linux.go @@ -0,0 +1,308 @@ +//go:build linux + +package uffd + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +// userfaultfd ioctl numbers and feature flags. The constants are derived +// from <linux/userfaultfd.h>: _IOWR(0xAA, ...) with the size of each +// argument struct in bits 16–29. +const ( + uffdAPI = 0xAA + uffdAPIFeature = 0x0 // we only need missing-page faults; no extra features. + + uffdioAPI = 0xC018AA3F // _IOWR(0xAA, 0x3F, struct uffdio_api{24}) + uffdioRegister = 0xC020AA00 // _IOWR(0xAA, 0x00, struct uffdio_register{32}) + uffdioCopyIoctl = 0xC028AA03 // _IOWR(0xAA, 0x03, struct uffdio_copy{40}) + uffdioZeropage = 0xC020AA04 // _IOWR(0xAA, 0x04, struct uffdio_zeropage{32}) + uffdRegMissing = 1 << 0 + uffdEventPagefnt = 0x12 // UFFD_EVENT_PAGEFAULT +) + +// uffdMsg mirrors struct uffd_msg from <linux/userfaultfd.h>. It is a +// 32-byte fixed-size record; we only consume the pagefault arm. +type uffdMsg struct { + Event uint8 + _ uint8 + _ uint16 + _ uint32 + Pagefault struct { + Flags uint64 + Address uint64 + Ptid uint32 + _ uint32 + } +} + +// uffdioAPIArg is struct uffdio_api. +type uffdioAPIArg struct { + API uint64 + Features uint64 + Ioctls uint64 +} + +// uffdioRegisterArg is struct uffdio_register. +type uffdioRegisterArg struct { + Start uint64 + Len uint64 + Mode uint64 + Ioctls uint64 +} + +// uffdioCopyArg is struct uffdio_copy.
+type uffdioCopyArg struct {
+	Dst  uint64
+	Src  uint64
+	Len  uint64
+	Mode uint64
+	Copy int64
+}
+
+// startListener opens the per-fork UDS, accepts firecracker's connection,
+// receives the userfaultfd via SCM_RIGHTS plus the JSON handshake, and
+// then runs the page-fault loop. The returned closer stops accept,
+// signals the handler, and removes the socket file.
+//
+// The handler goroutine is tracked by a WaitGroup that the closer waits
+// on, so every blocking step in the goroutine must be interruptible:
+// Accept is unblocked by closing the listener, and the handshake read is
+// unblocked by closing the accepted connection once the handler context
+// is cancelled. Handshake and setup failures end the goroutine silently;
+// nothing is reported to the caller.
+func (s *Server) startListener(ctx context.Context, forkID string, socketPath string) (func() error, error) {
+	// Remove any stale socket file from a prior run; UDS bind fails otherwise.
+	_ = os.Remove(socketPath)
+	ln, err := net.Listen("unix", socketPath)
+	if err != nil {
+		return nil, fmt.Errorf("uffd: listen %s: %w", socketPath, err)
+	}
+
+	hctx, hcancel := context.WithCancel(ctx)
+
+	var (
+		wg     sync.WaitGroup
+		mu     sync.Mutex
+		uffdFd int = -1
+		closed bool
+	)
+
+	closer := func() error {
+		mu.Lock()
+		if closed {
+			// Already closed once: just wait for the handler to finish.
+			mu.Unlock()
+			wg.Wait()
+			return nil
+		}
+		closed = true
+		fd := uffdFd
+		uffdFd = -1
+		mu.Unlock()
+
+		hcancel()
+		_ = ln.Close()
+		if fd >= 0 {
+			_ = unix.Close(fd)
+		}
+		wg.Wait()
+		_ = os.Remove(socketPath)
+		return nil
+	}
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		conn, err := ln.Accept()
+		if err != nil {
+			return
+		}
+		defer conn.Close()
+
+		// A peer that connects but never sends the handshake would
+		// otherwise pin this goroutine (and the closer's wg.Wait)
+		// forever. Close the connection as soon as the handler context
+		// is cancelled; the watcher exits once the handshake is over.
+		handshakeDone := make(chan struct{})
+		go func() {
+			select {
+			case <-hctx.Done():
+				_ = conn.Close()
+			case <-handshakeDone:
+			}
+		}()
+		fd, regions, err := receiveHandshake(conn)
+		close(handshakeDone)
+		if err != nil {
+			return
+		}
+		mu.Lock()
+		if closed {
+			// The closer already ran and could not see this fd.
+			mu.Unlock()
+			_ = unix.Close(fd)
+			return
+		}
+		uffdFd = fd
+		mu.Unlock()
+
+		if err := uffdAPIHandshake(fd); err != nil {
+			return
+		}
+		for _, r := range regions {
+			if err := uffdRegisterRegion(fd, r); err != nil {
+				return
+			}
+		}
+
+		s.servePageFaults(hctx, fd, regions, forkID)
+	}()
+
+	return closer, nil
+}
+
+// receiveHandshake reads firecracker's JSON payload and the userfaultfd
+// from conn. Firecracker sends them together; if the kernel splits them
+// across reads we loop until the fd arrives.
+//
+// The read goes through the connection itself (ReadMsgUnix) rather than a
+// duplicated descriptor, so closing conn from another goroutine aborts a
+// pending read. On success the caller owns the returned descriptor, which
+// is marked close-on-exec. On every error path any descriptor already
+// received is closed before returning.
+func receiveHandshake(conn net.Conn) (int, []MemoryRegion, error) {
+	uc, ok := conn.(*net.UnixConn)
+	if !ok {
+		return -1, nil, errors.New("uffd: connection is not a unix socket")
+	}
+
+	// Read until we have the SCM_RIGHTS fd. The JSON body is small, so
+	// a 4 KiB buffer plus one OOB control message is plenty.
+	buf := make([]byte, 4096)
+	oob := make([]byte, unix.CmsgSpace(4))
+	var jsonBytes []byte
+	fd := -1
+	// closeFd releases the received descriptor when the handshake fails
+	// after it has arrived.
+	closeFd := func() {
+		if fd >= 0 {
+			_ = unix.Close(fd)
+			fd = -1
+		}
+	}
+	for fd < 0 {
+		n, oobn, flags, _, err := uc.ReadMsgUnix(buf, oob)
+		if err != nil {
+			if errors.Is(err, io.EOF) {
+				return -1, nil, io.ErrUnexpectedEOF
+			}
+			return -1, nil, fmt.Errorf("uffd: recvmsg: %w", err)
+		}
+		if n == 0 && oobn == 0 {
+			return -1, nil, io.ErrUnexpectedEOF
+		}
+		jsonBytes = append(jsonBytes, buf[:n]...)
+		if oobn > 0 {
+			scms, perr := unix.ParseSocketControlMessage(oob[:oobn])
+			if perr != nil {
+				return -1, nil, fmt.Errorf("uffd: parse cmsg: %w", perr)
+			}
+			for i := range scms {
+				fds, ferr := unix.ParseUnixRights(&scms[i])
+				if ferr != nil {
+					closeFd()
+					return -1, nil, fmt.Errorf("uffd: parse fds: %w", ferr)
+				}
+				// Keep the first descriptor; close any others so
+				// they do not leak into this process.
+				for _, got := range fds {
+					if fd < 0 {
+						fd = got
+						unix.CloseOnExec(fd)
+						continue
+					}
+					_ = unix.Close(got)
+				}
+			}
+		}
+		// A truncated control message means descriptors may have been
+		// dropped by the kernel; do not continue with partial state.
+		if flags&unix.MSG_CTRUNC != 0 {
+			closeFd()
+			return -1, nil, errors.New("uffd: control message truncated")
+		}
+	}
+
+	hs, err := parseHandshake(jsonBytes)
+	if err != nil {
+		closeFd()
+		return -1, nil, err
+	}
+	return fd, hs.Mappings, nil
+}
+
+// uffdAPIHandshake issues UFFDIO_API on fd with the base API version and
+// no optional features.
+//
+// NOTE(review): ioctl_userfaultfd(2) documents EINVAL when UFFDIO_API is
+// issued on a userfaultfd that has already been enabled. If the process
+// that created the descriptor performs the handshake (and registers its
+// ranges) before sending it over the socket, this call would fail and the
+// handler would exit without serving any faults — confirm against a live
+// firecracker restore.
+func uffdAPIHandshake(fd int) error {
+	api := uffdioAPIArg{API: uffdAPI, Features: uffdAPIFeature}
+	if err := ioctl(fd, uffdioAPI, unsafe.Pointer(&api)); err != nil {
+		return fmt.Errorf("uffd: UFFDIO_API: %w", err)
+	}
+	return nil
+}
+
+// uffdRegisterRegion registers [r.BaseHostAddr, r.BaseHostAddr+r.Size) on
+// fd for missing-page faults via UFFDIO_REGISTER.
+func uffdRegisterRegion(fd int, r MemoryRegion) error {
+	reg := uffdioRegisterArg{
+		Start: uint64(r.BaseHostAddr),
+		Len:   r.Size,
+		Mode:  uffdRegMissing,
+	}
+	if err := ioctl(fd, uffdioRegister, unsafe.Pointer(&reg)); err != nil {
+		return fmt.Errorf("uffd: UFFDIO_REGISTER: %w", err)
+	}
+	return nil
+}
+
+// servePageFaults blocks reading uffd
+// events on fd. For each
+// UFFD_EVENT_PAGEFAULT we look up the region containing the faulting
+// address, read a page from the template mem-file, and call UFFDIO_COPY
+// to satisfy the fault.
+//
+// The descriptor is switched to non-blocking mode and waited on with
+// poll(2) using a short timeout, so the loop notices ctx cancellation
+// even when the guest is idle and no fault ever arrives. The loop returns
+// when ctx is done, when poll reports an error condition on fd, on a read
+// error other than EINTR/EAGAIN, on a short read, or when a fault cannot
+// be served. forkID is currently unused.
+func (s *Server) servePageFaults(ctx context.Context, fd int, regions []MemoryRegion, forkID string) {
+	// pollTimeoutMs bounds how long a cancellation can go unnoticed.
+	const pollTimeoutMs = 100
+
+	page := make([]byte, s.pageSize)
+	var msg uffdMsg
+	msgSize := int(unsafe.Sizeof(msg))
+	rawBuf := make([]byte, msgSize)
+
+	// poll(2) is only meaningful on a non-blocking userfaultfd, and a
+	// non-blocking read returns EAGAIN instead of parking this goroutine
+	// where cancellation cannot reach it.
+	// NOTE(review): O_NONBLOCK lives on the open file description shared
+	// with the sending process; this assumes the sender does not rely on
+	// blocking reads of the same descriptor — confirm.
+	if err := unix.SetNonblock(fd, true); err != nil {
+		return
+	}
+	pollFds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN}}
+
+	for {
+		if ctx.Err() != nil {
+			return
+		}
+		pollFds[0].Revents = 0
+		ready, err := unix.Poll(pollFds, pollTimeoutMs)
+		if err != nil {
+			if errors.Is(err, syscall.EINTR) {
+				continue
+			}
+			return
+		}
+		if ready == 0 {
+			// Timeout: loop around to re-check ctx.
+			continue
+		}
+		if pollFds[0].Revents&unix.POLLIN == 0 {
+			// Only an error/hangup condition was reported.
+			return
+		}
+		n, err := unix.Read(fd, rawBuf)
+		if err != nil {
+			// EAGAIN: the event was consumed or woken before we read it.
+			if errors.Is(err, syscall.EINTR) || errors.Is(err, syscall.EAGAIN) {
+				continue
+			}
+			return
+		}
+		if n != msgSize {
+			return
+		}
+		event := rawBuf[0]
+		if event != uffdEventPagefnt {
+			continue
+		}
+		// pagefault.address starts at offset 16 of uffd_msg.
+		addr := binary.LittleEndian.Uint64(rawBuf[16:24])
+		if err := s.copyPageForFault(fd, regions, addr, page); err != nil {
+			return
+		}
+	}
+}
+
+// copyPageForFault resolves one fault at addr: it rounds addr down to the
+// server page size, finds the region containing that page, reads the page
+// from the template mem-file at the region's file offset, and installs it
+// with UFFDIO_COPY. page is a caller-owned scratch buffer of one page that
+// is reused across calls.
+//
+// Returns an error if the address lies outside every region, the mem-file
+// read fails for a reason other than EOF, or UFFDIO_COPY fails with
+// anything other than EEXIST/EAGAIN.
+func (s *Server) copyPageForFault(fd int, regions []MemoryRegion, addr uint64, page []byte) error {
+	pageSize := uint64(s.pageSize)
+	// Assumes pageSize is a power of two, as the mask requires.
+	pageStart := addr &^ (pageSize - 1)
+
+	for _, r := range regions {
+		base := uint64(r.BaseHostAddr)
+		if pageStart < base || pageStart >= base+r.Size {
+			continue
+		}
+		offset := int64(r.MemFileOffset + (pageStart - base))
+		n, err := s.memFile.ReadAt(page, offset)
+		if err != nil && !errors.Is(err, io.EOF) {
+			return fmt.Errorf("uffd: read template at %d: %w", offset, err)
+		}
+		// A short read at end of file leaves the rest of the reused
+		// buffer holding bytes from an earlier fault; zero it so stale
+		// guest data is never copied into this page.
+		for i := n; i < len(page); i++ {
+			page[i] = 0
+		}
+		copyArg := uffdioCopyArg{
+			Dst: pageStart,
+			Src: uint64(uintptr(unsafe.Pointer(&page[0]))),
+			Len: pageSize,
+		}
+		if err := ioctl(fd, uffdioCopyIoctl, unsafe.Pointer(&copyArg)); err != nil {
+			// Spurious/duplicate faults can race other vCPUs; treat
+			// them as benign and keep serving.
+			// NOTE(review): EAGAIN is also swallowed without a retry;
+			// confirm the faulting thread is re-queued in that case.
+			if errors.Is(err, syscall.EEXIST) || errors.Is(err, syscall.EAGAIN) {
+				return nil
+			}
+			return fmt.Errorf("uffd: UFFDIO_COPY: %w", err)
+		}
+		return nil
+	}
+	return fmt.Errorf("uffd: fault addr 0x%x outside any registered region", addr)
+}
+
+// ioctl issues a raw ioctl(2) on fd with request req and argument pointer
+// arg, returning the errno as an error when the call fails.
+func ioctl(fd int, req uintptr, arg unsafe.Pointer) error {
+	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), req, uintptr(arg))
+	if errno != 0 {
+		return errno
+	}
+	return nil
+}
diff --git a/lib/uffd/server_other.go b/lib/uffd/server_other.go
new file mode 100644
index 00000000..7ce7b287
--- /dev/null
+++ b/lib/uffd/server_other.go
@@ -0,0 +1,15 @@
+//go:build !linux
+
+package uffd
+
+import "context"
+
+// startListener returns ErrUnsupported on non-Linux platforms.
+// userfaultfd is a Linux-only kernel feature; callers should fall back
+// to letting firecracker mmap the mem-file privately.
+func (s *Server) startListener(ctx context.Context, forkID string, socketPath string) (func() error, error) {
+	_ = ctx
+	_ = forkID
+	_ = socketPath
+	return nil, ErrUnsupported
+}
diff --git a/lib/uffd/uffd.go b/lib/uffd/uffd.go
new file mode 100644
index 00000000..9531f02b
--- /dev/null
+++ b/lib/uffd/uffd.go
@@ -0,0 +1,253 @@
+// Package uffd implements a userfaultfd page server for firecracker
+// snapshot fan-out. The server backs many concurrent forks against a
+// single read-only template mem-file: instead of letting firecracker
+// mmap the mem-file privately per fork (which forces every page to be
+// copied on first touch), firecracker is configured to use a
+// userfaultfd memory backend, and this server populates pages on
+// demand from the template file.
+//
+// One Server instance handles one template mem-file and any number of
+// fork connections. Each fork's firecracker process connects to a
+// per-fork UDS and hands the server its userfaultfd via SCM_RIGHTS
+// alongside a JSON payload describing the guest memory mappings; the
+// server then handles UFFDIO_COPY for every faulted page.
+// +// The protocol (firecracker_uffd_protocol below) is the contract +// firecracker speaks; we keep it isolated here so PR 8 can ride on +// top to prefetch hot pages without touching firecracker glue code. +// +// PR 5 ships the server skeleton, the protocol parser, and a unit +// test surface that doesn't require KVM. The hot-path syscalls live +// in server_linux.go behind a build tag because userfaultfd is a +// Linux-only kernel feature. +package uffd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "sync" +) + +// ErrUnsupported is returned on platforms where userfaultfd is not +// available. Callers should treat this as "fall back to mmap MAP_PRIVATE." +var ErrUnsupported = errors.New("userfaultfd unsupported on this platform") + +// MemoryRegion describes a contiguous region of guest physical memory +// that maps into a [BaseHostAddr, BaseHostAddr+Size) virtual range in +// the firecracker process. The server services UFFDIO_COPY into that +// range using bytes from MemFileOffset. +type MemoryRegion struct { + BaseHostAddr uintptr `json:"base_host_virt_addr"` + Size uint64 `json:"size"` + MemFileOffset uint64 `json:"offset"` +} + +// firecrackerHandshake is the JSON payload firecracker sends on its +// UDS connection right before it passes the userfaultfd via SCM_RIGHTS. +// We only use the fields we care about for serving page faults; the +// rest of firecracker's payload is ignored. +type firecrackerHandshake struct { + Mappings []MemoryRegion `json:"mappings"` +} + +// Config configures a Server. +type Config struct { + // MemFilePath is the path to the template mem-file. The server + // opens it read-only and serves pages from it. + MemFilePath string + + // SocketDir is where per-fork UDS files live. The directory must + // exist and be writable by the server. One UDS is created per + // RegisterFork call. + SocketDir string + + // PageSize is the target page size for UFFDIO_COPY. 
Must be a + // multiple of os.Getpagesize. Zero means use the host page size. + PageSize int +} + +// Server owns the template mem-file and dispatches userfaultfd events +// for every connected fork. It is safe for concurrent use; methods may +// be called from any goroutine. +type Server struct { + cfg Config + memFile *os.File + memSize int64 + + mu sync.Mutex + listens map[string]*forkListen // forkID -> per-fork bookkeeping + closed bool + pageSize int +} + +type forkListen struct { + socketPath string + closer func() error +} + +// NewServer opens the template mem-file and prepares the server. It +// does not start any goroutines yet; callers register forks one by one. +// When the server is closed, the mem-file fd is released; in-flight +// fork handlers are signaled to exit and joined. +func NewServer(cfg Config) (*Server, error) { + if cfg.MemFilePath == "" { + return nil, errors.New("uffd: MemFilePath is required") + } + if cfg.SocketDir == "" { + return nil, errors.New("uffd: SocketDir is required") + } + if err := os.MkdirAll(cfg.SocketDir, 0o755); err != nil { + return nil, fmt.Errorf("uffd: ensure socket dir: %w", err) + } + f, err := os.Open(cfg.MemFilePath) + if err != nil { + return nil, fmt.Errorf("uffd: open mem-file: %w", err) + } + st, err := f.Stat() + if err != nil { + _ = f.Close() + return nil, fmt.Errorf("uffd: stat mem-file: %w", err) + } + pageSize := cfg.PageSize + if pageSize == 0 { + pageSize = os.Getpagesize() + } + return &Server{ + cfg: cfg, + memFile: f, + memSize: st.Size(), + listens: map[string]*forkListen{}, + pageSize: pageSize, + }, nil +} + +// SocketPath returns the UDS path that should be passed to firecracker +// for a fork. RegisterFork must be called first. 
+func (s *Server) SocketPath(forkID string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return "", errors.New("uffd: server closed") + } + listen, ok := s.listens[forkID] + if !ok { + return "", fmt.Errorf("uffd: fork %q is not registered", forkID) + } + return listen.socketPath, nil +} + +// MemSize returns the size of the template mem-file in bytes. Useful +// for sizing prefetch buffers and validating handshake mappings. +func (s *Server) MemSize() int64 { return s.memSize } + +// PageSize returns the configured page size in bytes. +func (s *Server) PageSize() int { return s.pageSize } + +// Close stops the server, closes all per-fork listeners, and releases +// the template mem-file fd. After Close returns, the server cannot be +// reused. +func (s *Server) Close() error { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return nil + } + s.closed = true + listens := s.listens + s.listens = nil + s.mu.Unlock() + + var firstErr error + for _, l := range listens { + if l.closer != nil { + if err := l.closer(); err != nil && firstErr == nil { + firstErr = err + } + } + } + if err := s.memFile.Close(); err != nil && firstErr == nil { + firstErr = err + } + return firstErr +} + +// parseHandshake decodes firecracker's JSON handshake payload. Exposed +// so tests can validate the parser without spinning up a real socket. +func parseHandshake(data []byte) (firecrackerHandshake, error) { + var h firecrackerHandshake + if err := json.Unmarshal(data, &h); err != nil { + return firecrackerHandshake{}, fmt.Errorf("uffd: parse handshake: %w", err) + } + if len(h.Mappings) == 0 { + return firecrackerHandshake{}, errors.New("uffd: handshake has no mappings") + } + return h, nil +} + +// resolveSocketPath returns the per-fork socket path. The server uses +// short names because Unix domain sockets have a tight sun_path limit; +// callers should keep SocketDir short. 
+func (s *Server) resolveSocketPath(forkID string) string {
+	return filepath.Join(s.cfg.SocketDir, forkID+".uffd")
+}
+
+// RegisterFork allocates a per-fork listener and waits asynchronously
+// for firecracker to connect. It returns the UDS path that firecracker
+// should be pointed at. ctx is handed to the listener, which derives the
+// handler's context from it.
+//
+// forkID must be a single path element: it is used verbatim as the socket
+// file name, so values containing a path separator, "." or ".." are
+// rejected to keep the socket inside SocketDir. The fork's slot is
+// reserved under the server lock before the listener is started, so two
+// concurrent registrations of the same forkID cannot both succeed (the
+// second would otherwise unlink the first one's socket and orphan its
+// closer). While the reservation is pending, SocketPath already reports
+// the path even though nothing is listening yet.
+//
+// On Linux the heavy lifting (accept, recvmsg, ioctl loop) lives in
+// server_linux.go; on other platforms RegisterFork returns ErrUnsupported.
+func (s *Server) RegisterFork(ctx context.Context, forkID string) (string, error) {
+	if forkID == "" {
+		return "", errors.New("uffd: fork id is required")
+	}
+	if forkID == "." || forkID == ".." || filepath.Base(forkID) != forkID {
+		return "", fmt.Errorf("uffd: fork id %q must be a plain file name", forkID)
+	}
+	s.mu.Lock()
+	if s.closed {
+		s.mu.Unlock()
+		return "", errors.New("uffd: server closed")
+	}
+	if _, dup := s.listens[forkID]; dup {
+		s.mu.Unlock()
+		return "", fmt.Errorf("uffd: fork %q already registered", forkID)
+	}
+	socketPath := s.resolveSocketPath(forkID)
+	// Reserve the slot now; closer stays nil until the listener is up.
+	// Close and UnregisterFork both skip entries with a nil closer.
+	reserved := &forkListen{socketPath: socketPath}
+	s.listens[forkID] = reserved
+	s.mu.Unlock()
+
+	closer, err := s.startListener(ctx, forkID, socketPath)
+
+	s.mu.Lock()
+	// The reservation is gone if the server was closed or the fork was
+	// unregistered while the listener was starting.
+	stillReserved := s.listens[forkID] == reserved
+	if err != nil {
+		if stillReserved {
+			delete(s.listens, forkID)
+		}
+		s.mu.Unlock()
+		return "", err
+	}
+	if s.closed || !stillReserved {
+		serverClosed := s.closed
+		s.mu.Unlock()
+		_ = closer()
+		if serverClosed {
+			return "", errors.New("uffd: server closed during register")
+		}
+		return "", fmt.Errorf("uffd: fork %q unregistered during register", forkID)
+	}
+	reserved.closer = closer
+	s.mu.Unlock()
+
+	return socketPath, nil
+}
+
+// UnregisterFork closes the listener for forkID. Called when the fork
+// is destroyed; the server stops servicing its faults and removes the
+// UDS file.
+func (s *Server) UnregisterFork(forkID string) error { + s.mu.Lock() + listen, ok := s.listens[forkID] + if !ok { + s.mu.Unlock() + return nil + } + delete(s.listens, forkID) + s.mu.Unlock() + if listen.closer != nil { + return listen.closer() + } + return nil +} diff --git a/lib/uffd/uffd_test.go b/lib/uffd/uffd_test.go new file mode 100644 index 00000000..c6b2fc3d --- /dev/null +++ b/lib/uffd/uffd_test.go @@ -0,0 +1,127 @@ +package uffd + +import ( + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeTempMemFile(t *testing.T, size int) string { + t.Helper() + path := filepath.Join(t.TempDir(), "memory") + f, err := os.Create(path) + require.NoError(t, err) + require.NoError(t, f.Truncate(int64(size))) + require.NoError(t, f.Close()) + return path +} + +func TestNewServer_RequiresMemFile(t *testing.T) { + _, err := NewServer(Config{SocketDir: t.TempDir()}) + assert.Error(t, err) +} + +func TestNewServer_RequiresSocketDir(t *testing.T) { + _, err := NewServer(Config{MemFilePath: writeTempMemFile(t, 4096)}) + assert.Error(t, err) +} + +func TestNewServer_ReportsMemSizeAndPageSize(t *testing.T) { + memPath := writeTempMemFile(t, 16384) + s, err := NewServer(Config{ + MemFilePath: memPath, + SocketDir: t.TempDir(), + PageSize: 4096, + }) + require.NoError(t, err) + defer s.Close() + + assert.Equal(t, int64(16384), s.MemSize()) + assert.Equal(t, 4096, s.PageSize()) +} + +func TestNewServer_DefaultsPageSizeToHost(t *testing.T) { + s, err := NewServer(Config{ + MemFilePath: writeTempMemFile(t, 4096), + SocketDir: t.TempDir(), + }) + require.NoError(t, err) + defer s.Close() + + assert.Equal(t, os.Getpagesize(), s.PageSize()) +} + +func TestSocketPath_UnregisteredFork(t *testing.T) { + s, err := NewServer(Config{ + MemFilePath: writeTempMemFile(t, 4096), + SocketDir: t.TempDir(), + }) + require.NoError(t, err) + defer s.Close() + + _, err = s.SocketPath("missing") + 
assert.Error(t, err) +} + +func TestUnregisterFork_UnknownIsNoop(t *testing.T) { + s, err := NewServer(Config{ + MemFilePath: writeTempMemFile(t, 4096), + SocketDir: t.TempDir(), + }) + require.NoError(t, err) + defer s.Close() + + assert.NoError(t, s.UnregisterFork("does-not-exist")) +} + +func TestClose_Idempotent(t *testing.T) { + s, err := NewServer(Config{ + MemFilePath: writeTempMemFile(t, 4096), + SocketDir: t.TempDir(), + }) + require.NoError(t, err) + require.NoError(t, s.Close()) + assert.NoError(t, s.Close()) +} + +func TestParseHandshake_GoodPayload(t *testing.T) { + data := []byte(`{"mappings":[{"base_host_virt_addr":4096,"size":8192,"offset":0}]}`) + hs, err := parseHandshake(data) + require.NoError(t, err) + require.Len(t, hs.Mappings, 1) + assert.Equal(t, uintptr(4096), hs.Mappings[0].BaseHostAddr) + assert.Equal(t, uint64(8192), hs.Mappings[0].Size) + assert.Equal(t, uint64(0), hs.Mappings[0].MemFileOffset) +} + +func TestParseHandshake_RejectsEmptyMappings(t *testing.T) { + _, err := parseHandshake([]byte(`{"mappings":[]}`)) + assert.Error(t, err) +} + +func TestParseHandshake_RejectsBadJSON(t *testing.T) { + _, err := parseHandshake([]byte(`{not json`)) + assert.Error(t, err) +} + +func TestResolveSocketPath_PerFork(t *testing.T) { + dir := t.TempDir() + s, err := NewServer(Config{ + MemFilePath: writeTempMemFile(t, 4096), + SocketDir: dir, + }) + require.NoError(t, err) + defer s.Close() + + got := s.resolveSocketPath("fork-1") + assert.Equal(t, filepath.Join(dir, "fork-1.uffd"), got) +} + +func TestErrUnsupportedSentinel(t *testing.T) { + // The sentinel must be a stable error value so callers can switch on it. + assert.True(t, errors.Is(ErrUnsupported, ErrUnsupported)) +}