Research
February 23, 2026
15 min read

Exploring Zero-Copy mmap Loading and KV Cache Pre-Allocation for MLX on Apple Silicon

Can llama.cpp’s memory-mapped loading and pre-allocated KV caches improve MLX on Apple Silicon? We implemented both techniques and benchmarked eight Qwen3 models on M1 Max—the results reveal why MLX’s existing design is already well-suited to its lazy evaluation model.

View on GitHub
8 Models Tested · 20.65x Best mmap Speedup · <5% KV Cache Throughput Delta · 1 Bug Fixed

Abstract

llama.cpp achieves remarkably flat memory behavior on Apple Silicon through memory-mapped (mmap) weight loading and pre-allocated KV caches. We investigate whether these same techniques can benefit the MLX framework, which instead uses pread-based loading and dynamically grown KV caches.

We implement a zero-copy mmap loading path in MLX’s C++ core and a KV cache pre-allocation option in mlx-lm, and evaluate both across eight Qwen3 quantized model variants on an M1 Max (32 GB). Our results are mixed: mmap loading shows dramatic speedups for certain larger models (up to 20.65x) but performs worse for small models. KV pre-allocation flattens memory growth but adds 0.5–0.6 GB upfront cost with no throughput benefit. Overall, MLX’s existing memory management is already well-suited to its design goals.

Background

Apple Silicon’s Unified Memory Architecture (UMA) enables CPU and GPU to share the same physical memory pool with up to 800 GB/s bandwidth—eliminating PCIe bottlenecks and enabling zero-copy data sharing. This makes mmap an attractive loading strategy in theory.

MLX Approach

pread()-based loading into allocated buffers. Lazy evaluation model with dynamic KV cache growing in 256-token increments. Safetensors format.

llama.cpp Approach

mmap for zero-copy weight access. Static computation graph (ggml) with pre-allocated KV cache at full context length. GGUF format with guaranteed alignment.

The key question: would MLX benefit from adopting llama.cpp’s memory strategies? We implemented both techniques to find out.

Implementation

Zero-Copy mmap Loading

We implemented an MmapReader class in MLX’s C++ core that memory-maps safetensors files and exposes offset views via Metal buffers:

// Memory-map a safetensors file read-only; the mapping outlives the fd.
MmapReader::MmapReader(std::string file_path) {
  int fd = open(file_path.c_str(), O_RDONLY);
  if (fd < 0) {
    throw std::runtime_error("open failed: " + file_path);
  }
  struct stat st;
  if (fstat(fd, &st) != 0) {
    close(fd);
    throw std::runtime_error("fstat failed: " + file_path);
  }
  file_size_ = st.st_size;
  mmap_ptr_ = mmap(nullptr, file_size_,
    PROT_READ, MAP_PRIVATE, fd, 0);
  close(fd);  // mmap survives fd closure
  if (mmap_ptr_ == MAP_FAILED) {
    throw std::runtime_error("mmap failed: " + file_path);
  }
  madvise(mmap_ptr_, file_size_, MADV_SEQUENTIAL);
}

The critical zero-copy path creates offset views into the mmap buffer with a careful alignment check. Safetensors does not guarantee element-aligned offsets, so we verify offset % itemsize == 0 before allowing zero-copy access. Unaligned offsets fall back to memcpy.

if (reader_->is_mmap() && !swap_endianness_) {
  auto mmap_reader = std::dynamic_pointer_cast<MmapReader>(reader_);
  // Zero-copy requires the tensor's byte offset to be element-aligned.
  if (mmap_reader && (offset_ % out.itemsize() == 0)) {
    auto metal_buf = mmap_reader->get_metal_buffer();
    // Wrap the mapped file as a parent array; the captured reader_ref
    // keeps the mapping alive for as long as any view references it.
    auto parent = array(metal_buf, {1}, uint8,
      [reader_ref](Buffer) { /* prevent dealloc */ });
    out.copy_shared_buffer(parent, strides, flags,
      out.size(), offset_ / out.itemsize());
    return;  // Zero-copy success
  }
}
// Fallback: allocate + memcpy for unaligned offsets
// Fallback: allocate + memcpy for unaligned offsets

KV Cache Pre-Allocation

We added an optional max_context_length parameter for upfront KV cache allocation. When specified, the key and value tensors are allocated for the full context window at startup, with mx.eval() forcing immediate physical allocation. If the sequence exceeds the pre-allocated length, the cache gracefully falls back to dynamic expansion.

class KVCache(_BaseCache):
    step = 256

    def __init__(self, n_kv_heads=0, head_dim=0,
                 max_context_length=0, dtype=mx.float16):
        self.offset = 0
        self.keys = None
        self.values = None
        if max_context_length > 0 and n_kv_heads > 0:
            # Round the requested context up to the growth step.
            L = ((max_context_length + self.step - 1)
                 // self.step) * self.step
            self.keys = mx.zeros(
                (1, n_kv_heads, L, head_dim), dtype=dtype)
            self.values = mx.zeros(
                (1, n_kv_heads, L, head_dim), dtype=dtype)
            mx.eval(self.keys, self.values)  # Force physical allocation

Quantized dtype Bug Fix

During implementation, we discovered that QuantizedLinear.weight.dtype returns uint32 (the packed storage type) rather than the working precision. This caused the KV cache to be allocated in uint32 instead of float16. The fix uses scales.dtype instead, which stores dequantization factors in the correct working dtype.
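The logic of the fix can be sketched with a small helper. The attribute names mirror the quantized-layer layout described above, but the helper itself and the stand-in objects are illustrative, not mlx-lm's actual code:

```python
from types import SimpleNamespace

def kv_cache_dtype(layer):
    """Choose the KV cache dtype for a layer that may be quantized.

    Quantized layers pack weights into uint32 words, so weight.dtype
    reflects packed storage, not working precision; the dequantization
    scales carry the real compute dtype.
    """
    scales = getattr(layer, "scales", None)
    if scales is not None:
        return scales.dtype
    return layer.weight.dtype

# Stand-ins for a quantized layer and a plain linear layer:
quantized = SimpleNamespace(
    weight=SimpleNamespace(dtype="uint32"),   # packed storage type
    scales=SimpleNamespace(dtype="float16"),  # working precision
)
plain = SimpleNamespace(weight=SimpleNamespace(dtype="float16"))

print(kv_cache_dtype(quantized))  # float16, not uint32
print(kv_cache_dtype(plain))      # float16
```

Without the scales check, the quantized layer's cache would inherit uint32 and silently waste memory while breaking attention math.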

Evaluation

Test Setup

  • Hardware: Apple M1 Max, 32 GB unified memory
  • Software: macOS, MLX 0.30.7, Python 3.11
  • Models: 8 Qwen3 quantized variants (4B to 14B, 3-bit to 8-bit)
  • Prompt: 34-token fixed physics question; 200 output tokens per run
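Load times of the kind reported below can be measured with a small best-of-N harness; this is a simplified sketch rather than our full benchmark script, and the function name is ours:

```python
import time

def best_load_time(load_fn, repeats=3):
    """Return the best-of-N wall-clock time (seconds) for a loader.

    Best-of-N rather than the mean reduces noise from page-cache
    state, which matters when comparing pread- vs. mmap-based loading.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        load_fn()
        best = min(best, time.perf_counter() - start)
    return best
```

In practice load_fn would wrap the framework's model loader (e.g. a lambda around mlx_lm's load for a given model path), with the page cache purged between cold-start runs.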

Loading Speed: Standard vs. mmap

Model            Standard (s)   Mmap (s)   Speedup
Qwen3-4B-4bit    0.101          0.131      0.77x
Qwen3-4B-8bit    0.184          0.246      0.75x
Qwen3-8B-3bit    0.146          0.075      1.95x
Qwen3-8B-4bit    0.352          0.328      1.07x
Qwen3-8B-6bit    0.637          0.435      1.46x
Qwen3-8B-8bit    2.572          0.125      20.65x
Qwen3-14B-4bit   2.323          0.535      4.34x
Qwen3-14B-6bit   3.701          5.388      0.69x

Results are highly inconsistent. Three out of eight models load slower with mmap. The 20.65x speedup on Qwen3-8B-8bit is an outlier—likely due to the standard loader hitting a pathological pread pattern. For most models in the 2–6 GB range, MLX’s standard loading is already efficient enough.

Inference Impact

Model            Gen t/s (Std)   Gen t/s (Mmap)   Mem Std (GB)   Mem Mmap (GB)
Qwen3-4B-4bit    56.9            58.5             2.38           2.38
Qwen3-8B-4bit    30.1            30.1             4.72           8.82*
Qwen3-8B-8bit    21.6            21.3             8.82           8.84
Qwen3-14B-4bit   16.2            16.5             8.84           11.09*
Qwen3-14B-6bit   12.1            12.0             12.08          12.16

* Memory nearly doubled—mmap region coexists with materialized copies due to lazy evaluation.

Key Inference Findings

  1. Generation throughput is unchanged (<3% difference across all models). Once weights are in memory, the computation is identical.
  2. Mmap can increase peak memory. For 8B-4bit and 14B-4bit, peak memory nearly doubled—MLX materializes quantized tensors during lazy evaluation while mmap-backed originals remain pinned.

KV Cache Pre-Allocation

Model            Mode           FTLT (ms)   Gen (t/s)   Peak (GB)   ΔMem (GB)
Qwen3-4B-4bit    Dynamic        260.4       62.0        2.221
                 Pre-alloc 2k   276.2       63.0        2.471       +0.250
                 Pre-alloc 4k   285.2       62.3        2.736       +0.515
Qwen3-8B-4bit    Dynamic        419.2       31.3        4.395
                 Pre-alloc 2k   428.0       30.3        4.633       +0.238
                 Pre-alloc 4k   436.1       29.7        4.908       +0.513
Qwen3-14B-4bit   Dynamic        859.6       16.7        7.827
                 Pre-alloc 2k   917.4       16.0        8.101       +0.274
                 Pre-alloc 4k   901.5       16.1        8.414       +0.587

Pre-allocation does not improve throughput: generation speed stays within 5% of dynamic mode. First-token latency increases by 5–7% due to the forced mx.eval() at startup. Memory overhead is predictable: ~0.25 GB for 2048 tokens and ~0.5 GB for 4096 tokens.
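The measured overhead matches a back-of-envelope estimate. Using Qwen3-8B-like dimensions as illustrative assumptions (36 layers, 8 KV heads, head dim 128), a float16 cache pre-allocated to 4096 tokens costs:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, length,
                   bytes_per_element=2, step=256):
    """Estimate pre-allocated KV cache size in bytes.

    Length is rounded up to the 256-token growth step; the factor
    of 2 accounts for separate key and value tensors, and float16
    gives 2 bytes per element.
    """
    rounded = ((length + step - 1) // step) * step
    return 2 * n_layers * n_kv_heads * head_dim * rounded * bytes_per_element

# Qwen3-8B-like shape (illustrative assumption):
size = kv_cache_bytes(n_layers=36, n_kv_heads=8, head_dim=128, length=4096)
print(size / 2**30)  # ~0.56 GiB, in line with the +0.513 GB measured
```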

Why MLX’s Existing Design Works Well

1. pread is fast enough

For models under 4 GB (the majority of quantized models run locally), standard loading completes in under 200 ms. At this scale, mmap’s overhead—VMA creation, page table setup, TLB pressure—outweighs any benefit from avoiding copies.

2. Lazy evaluation conflicts with mmap

When quantized tensors loaded via mmap are later materialized (e.g., dequantization), MLX allocates new buffers while mmap-backed originals remain referenced. This can lead to memory doubling rather than saving—the exact opposite of the goal.

3. Dynamic KV growth is not a bottleneck

The 256-token step growth pattern does not cause meaningful overhead. Generation throughput is identical whether the cache is pre-allocated or not. MLX’s allocator handles periodic growth efficiently.
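A quick estimate shows why. Assuming each growth event concatenates the existing cache into a buffer one step larger (an assumption about the allocator's behavior, not a claim about MLX internals), the total copy traffic over a generation stays small:

```python
def growth_copy_tokens(n_tokens, step=256):
    """Total token-positions copied across all growth events,
    assuming each growth copies the existing cache into a new
    buffer that is one step larger."""
    events = (n_tokens + step - 1) // step
    # Before the k-th growth, (k - 1) * step tokens already exist.
    return sum((k - 1) * step for k in range(1, events + 1))

# A 4096-token generation triggers 16 growth events and copies
# 30720 token-positions in total (~7.5x the final cache, spread
# over the whole run) -- negligible next to attention compute.
print(growth_copy_tokens(4096))
```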

The 14B-6bit Anomaly

The largest model tested (11.18 GB) loaded slower with mmap (0.69x). At this size, the mmap approach must fault in roughly 700,000 pages. The OS page-fault handler becomes the bottleneck: each 16 KB page requires a kernel trap, a page-table update, and a TLB insertion. The standard loader's buffered reads, which batch these operations, prove more efficient at scale.
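The two access patterns are easy to contrast at small scale with Python's mmap module. This is a toy illustration on a scratch file, not our benchmark, and absolute timings will be dominated by interpreter overhead:

```python
import mmap
import os
import tempfile
import time

# Create a 64 MB scratch file of zeros (real weights are gigabytes).
size = 64 * 1024 * 1024
fd, path = tempfile.mkstemp()
os.write(fd, b"\0" * size)
os.close(fd)

# Buffered read: one sequential pass through the file.
t0 = time.perf_counter()
with open(path, "rb") as f:
    data = f.read()
read_s = time.perf_counter() - t0

# mmap: the mapping itself is instant, but touching the data pays
# a fault per page (16 KB pages on Apple Silicon).
t0 = time.perf_counter()
with open(path, "rb") as f:
    mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    total = sum(mm[i] for i in range(0, size, 16384))  # touch each page
    mm.close()
mmap_s = time.perf_counter() - t0

print(f"read: {read_s:.3f}s  mmap touch: {mmap_s:.3f}s")
os.remove(path)
```

The point is not the toy numbers but the shape of the cost: read amortizes kernel crossings over large buffered chunks, while a faulted-in mapping pays per page.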

Recommendations

For most users (Default)

Stick with MLX’s default loading and dynamic KV cache—fast, memory-efficient, and well-integrated with lazy evaluation.

Mmap loading (Opt-in)

Useful as an opt-in for specific large models where standard loading is slow, but it should not be the default.

KV pre-allocation (Niche)

Only worthwhile for server deployments that value memory predictability, or for long-context scenarios (>8k tokens).

Conclusion

We explored whether llama.cpp’s memory management strategies—mmap loading and KV cache pre-allocation—could improve MLX’s performance on Apple Silicon. Our systematic benchmarks across eight models reveal inconsistent results: mmap helps for some larger models but hurts for small ones and can increase peak memory; KV pre-allocation achieves flat memory but offers no throughput benefit.

The conclusion: MLX’s existing memory management is already well-designed for its target use case. The design choices that differ from llama.cpp are not oversights but appropriate adaptations to MLX’s lazy evaluation model and the safetensors ecosystem. During this work, we also identified and fixed a quantized model dtype inference bug—arguably the most practically useful contribution.

Citation

@article{atomgradient2026optmlx,
  title={Exploring Zero-Copy mmap Loading and KV Cache
         Pre-Allocation for MLX on Apple Silicon},
  author={AtomGradient},
  year={2026},
  url={https://github.com/AtomGradient/OptMLX}
}