The Engine class provides efficient inference for NanoChat models using KV caching and supports advanced features like tool use and multi-sample generation.
Overview
The engine provides four main features:
- KV Cache: Stores key-value pairs from previous tokens to avoid recomputation
- Streaming Generation: Yields tokens one at a time for real-time output
- Batch Generation: Generate multiple samples in parallel
- Tool Use: Built-in calculator tool with automatic result injection
Basic Usage
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
# Load model and tokenizer (`device` must already be set up; see the complete example below)
model, tokenizer, meta = load_model("sft", device, phase="eval")
# Create engine
engine = Engine(model, tokenizer)
# Generate tokens
bos_token_id = tokenizer.get_bos_token_id()
prompt_tokens = tokenizer.encode("What is 2+2?", prepend=bos_token_id)
for token_column, token_masks in engine.generate(prompt_tokens, num_samples=1, max_tokens=100):
    token = token_column[0]  # num_samples=1, so the column has a single entry
    print(tokenizer.decode([token]), end="", flush=True)
Generation Methods
Streaming Generation
generate(tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42)
Streaming generator that yields tokens one at a time.
Parameters:
tokens (list[int]): Input token sequence
num_samples (int): Number of parallel samples to generate (default: 1)
max_tokens (int): Maximum tokens to generate (default: None = no explicit limit; the decode KV cache is then sized to the model's sequence_len)
temperature (float): Sampling temperature, 0.0 = greedy (default: 1.0)
top_k (int): Top-k sampling parameter (default: None)
seed (int): Random seed (default: 42)
Yields:
token_column (list[int]): Next token for each sample (length = num_samples)
token_masks (list[int]): 1 if sampled, 0 if forced by tool (length = num_samples)
Example:
for token_column, token_masks in engine.generate(
    prompt_tokens,
    num_samples=4,  # Generate 4 samples in parallel
    max_tokens=256,
    temperature=0.8,
    top_k=50,
    seed=12345
):
    for i, (token, mask) in enumerate(zip(token_column, token_masks)):
        if mask == 1:
            print(f"Sample {i}: {tokenizer.decode([token])}")
        else:
            print(f"Sample {i}: [FORCED] {tokenizer.decode([token])}")
Batch Generation
generate_batch(tokens, num_samples=1, **kwargs)
Non-streaming batch generation that returns complete token sequences.
Returns:
results (list[list[int]]): Token sequences for each sample
masks (list[list[int]]): Mask sequences (1=sampled, 0=forced)
Example:
results, masks = engine.generate_batch(
    prompt_tokens,
    num_samples=4,
    max_tokens=128,
    temperature=0.7
)
for i, (tokens, mask) in enumerate(zip(results, masks)):
    text = tokenizer.decode(tokens)
    print(f"Sample {i}: {text}")
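The relationship between the streaming and batch methods can be illustrated with a small sketch that folds (token_column, token_masks) pairs into per-sample rows. This is not the engine's own implementation: `collect_stream` and `fake_stream` are hypothetical names, the stream is a scripted stand-in for `engine.generate(...)`, and the choice to start each row with the prompt (masked as 0) is an assumption that matches how the complete example on this page slices off the first `len(tokens)` entries. Stopping rows at terminal tokens is not modeled.

```python
from typing import Iterable, Iterator


def collect_stream(
    prompt: list[int],
    stream: Iterable[tuple[list[int], list[int]]],
    num_samples: int,
) -> tuple[list[list[int]], list[list[int]]]:
    """Accumulate (token_column, token_masks) pairs into per-sample rows."""
    # Assumption of this sketch: every row starts with a copy of the prompt,
    # and prompt positions get mask 0 because the model did not sample them.
    results = [list(prompt) for _ in range(num_samples)]
    masks = [[0] * len(prompt) for _ in range(num_samples)]
    for token_column, token_masks in stream:
        for i, (token, mask) in enumerate(zip(token_column, token_masks)):
            results[i].append(token)
            masks[i].append(mask)
    return results, masks


def fake_stream() -> Iterator[tuple[list[int], list[int]]]:
    """Hypothetical stand-in for engine.generate(...) with num_samples=2."""
    yield [10, 20], [1, 1]
    yield [11, 21], [1, 0]  # second sample received a forced token here
    yield [12, 22], [1, 1]


results, masks = collect_stream([1, 2], fake_stream(), num_samples=2)
print(results)  # [[1, 2, 10, 11, 12], [1, 2, 20, 21, 22]]
print(masks)    # [[0, 0, 1, 1, 1], [0, 0, 1, 0, 1]]
```

Each yielded column contributes exactly one token to every row, which is why the streaming and batch forms carry the same information.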
KV Cache
The KV cache stores key-value pairs from attention layers to avoid recomputing them for previous tokens.
Architecture
From nanochat/engine.py:83-133:
class KVCache:
    """
    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.
    Key differences from FA2-style cache:
    - Tensors are (B, T, H, D) not (B, H, T, D)
    - FA3 updates the cache in-place during flash_attn_with_kvcache
    - Position tracked per batch element via cache_seqlens tensor
    """
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        self.n_layers = num_layers
        self.n_heads = num_heads
        self.head_dim = head_dim
        # Pre-allocate cache tensors: (n_layers, B, T, H, D)
        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        # Current sequence length per batch element (FA3 needs int32)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
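Because both tensors are pre-allocated at shape (n_layers, B, T, H, D), the cache's memory footprint follows directly from the constructor arguments. The sketch below works through that arithmetic. `kv_cache_bytes` is a hypothetical helper, the configuration values are made up for illustration and do not describe any particular NanoChat model, and 2 bytes per element assumes a 16-bit dtype such as bfloat16.

```python
def kv_cache_bytes(num_layers, batch_size, seq_len, num_heads, head_dim, bytes_per_element=2):
    """Bytes held by k_cache plus v_cache, each shaped (n_layers, B, T, H, D)."""
    elements_per_tensor = num_layers * batch_size * seq_len * num_heads * head_dim
    return 2 * elements_per_tensor * bytes_per_element  # factor 2: keys and values


# Hypothetical configuration, chosen only to make the arithmetic concrete.
size = kv_cache_bytes(num_layers=20, batch_size=4, seq_len=2048, num_heads=10, head_dim=128)
print(f"{size / 2**20:.0f} MiB")  # 800 MiB
```

The footprint grows linearly with both batch_size and seq_len, so the number of parallel samples and the maximum generation length trade off directly against available memory. The cache_seqlens tensor adds only four bytes per batch element and is ignored here.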
Key Methods
reset(): Reset cache to empty state
get_pos(): Get current position (assumes all batch elements at same position)
get_layer_cache(layer_idx): Return (k_cache, v_cache) views for a specific layer
advance(num_tokens): Advance the cache position by num_tokens
prefill(other): Copy cached KV from another cache (used for multi-sample generation)
Prefill-then-Decode Pattern
The engine uses an efficient two-phase approach:
- Prefill: Process the entire prompt once with a batch size of 1
- Decode: Clone the KV cache for each sample and generate in parallel
From nanochat/engine.py:194-218:
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
    batch_size=1,
    seq_len=len(tokens),
    device=device,
    dtype=dtype,
    **kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :].expand(num_samples, -1) # (num_samples, vocab_size)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
    batch_size=num_samples,
    seq_len=kv_length_hint,
    device=device,
    dtype=dtype,
    **kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around
This approach processes the prompt once and then generates multiple diverse samples efficiently.
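To make the replication step concrete, here is a toy, list-based sketch of copying a batch-1 prefill cache into every row of a larger decode cache. `ToyKVCache` is a hypothetical stand-in and not the real class: it stores one placeholder string per position instead of an (H, D) tensor slice, and the body of its `prefill` method is an assumption about what such a copy has to do, not a transcription of the engine's code.

```python
class ToyKVCache:
    """List-based stand-in for the tensor cache; one placeholder per position."""

    def __init__(self, batch_size, seq_len, num_layers):
        self.batch_size = batch_size
        self.max_seq_len = seq_len
        # k[layer][row][position]; None marks a slot that has not been written
        self.k = [[[None] * seq_len for _ in range(batch_size)] for _ in range(num_layers)]
        self.v = [[[None] * seq_len for _ in range(batch_size)] for _ in range(num_layers)]
        self.cache_seqlens = [0] * batch_size

    def prefill(self, other):
        """Copy a batch-1 cache into every row of this cache."""
        assert other.batch_size == 1, "source cache must hold a single row"
        assert len(other.k) == len(self.k), "layer counts must match"
        pos = other.cache_seqlens[0]
        assert pos <= self.max_seq_len, "destination cache is too short"
        for layer in range(len(self.k)):
            for row in range(self.batch_size):
                self.k[layer][row][:pos] = other.k[layer][0][:pos]
                self.v[layer][row][:pos] = other.v[layer][0][:pos]
        self.cache_seqlens = [pos] * self.batch_size


prompt_len, max_tokens, num_samples = 3, 4, 2

# Phase 1: pretend the model wrote keys/values for the 3 prompt positions.
prefill_cache = ToyKVCache(batch_size=1, seq_len=prompt_len, num_layers=2)
for layer in range(2):
    prefill_cache.k[layer][0] = [f"k{layer}_{t}" for t in range(prompt_len)]
    prefill_cache.v[layer][0] = [f"v{layer}_{t}" for t in range(prompt_len)]
prefill_cache.cache_seqlens = [prompt_len]

# Phase 2: a larger cache with one row per sample and room for new tokens.
decode_cache = ToyKVCache(batch_size=num_samples, seq_len=prompt_len + max_tokens, num_layers=2)
decode_cache.prefill(prefill_cache)

print(decode_cache.cache_seqlens)  # [3, 3]
print(decode_cache.k[1][1])        # ['k1_0', 'k1_1', 'k1_2', None, None, None, None]
```

Every row starts decoding from the same position with identical cached prompt state, so the rows diverge only through sampling.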
Token Sampling
The engine uses a custom sampling function that supports temperature and top-k sampling.
From nanochat/engine.py:135-152:
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
Sampling Modes:
temperature=0.0: Greedy decoding (always pick most likely token)
temperature=1.0: Standard sampling from full distribution
temperature>1.0: More random/creative (flattens distribution)
temperature<1.0: More focused/deterministic (sharpens distribution)
top_k: Only sample from top-k most likely tokens
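The effect of these settings is easy to see on a tiny logit vector. The following is a dependency-free sketch for a single row that mirrors the branching of the function above using only the standard library. `softmax` and `sample_index` are hypothetical helper names, and the logits are invented for illustration.

```python
import math
import random


def softmax(values, temperature=1.0):
    """Softmax over a list of logits after dividing by the temperature."""
    scaled = [v / temperature for v in values]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, rng, temperature=1.0, top_k=None):
    """Pick one index from a single row of logits."""
    if temperature == 0.0:
        # Greedy: index of the largest logit
        return max(range(len(logits)), key=lambda i: logits[i])
    indices = list(range(len(logits)))
    if top_k is not None and top_k > 0:
        k = min(top_k, len(logits))
        # Keep only the k indices with the largest logits
        indices = sorted(indices, key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in indices], temperature)
    return rng.choices(indices, weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
for t in (0.5, 1.0, 2.0):
    # Lower temperature puts more probability mass on index 0
    print(t, [round(p, 3) for p in softmax(logits, t)])

rng = random.Random(0)
print(sample_index(logits, rng, temperature=0.0))                    # 0 (greedy)
print({sample_index(logits, rng, top_k=2) for _ in range(200)} <= {0, 1})  # True
```

With top_k=2 only the two highest-scoring indices can ever be drawn, regardless of temperature.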
Tool Use
The engine includes built-in support for a calculator tool. When the model generates special tokens, the engine automatically evaluates expressions and injects results.
How It Works
- Model generates the <|python_start|> token
- Engine enters "python block" mode and accumulates tokens
- Model generates the <|python_end|> token
- Engine evaluates the expression using use_calculator()
- If successful, engine forces the <|output_start|> + result + <|output_end|> tokens
- Model continues generation with the result in context
From nanochat/engine.py:251-267:
# Handle tool logic
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
    state.python_expr_tokens = []
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)
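The flow described above can be replayed end to end on a scripted token stream. This is a simplified sketch, not the engine's loop: tokens are plain strings instead of ids, `toy_calculator` is a hypothetical stand-in for use_calculator() that only adds two integers, and `run_row` is a made-up helper that models a single row.

```python
from collections import deque

PYTHON_START, PYTHON_END = "<|python_start|>", "<|python_end|>"
OUTPUT_START, OUTPUT_END = "<|output_start|>", "<|output_end|>"


def toy_calculator(expr):
    """Hypothetical stand-in for use_calculator: handles only 'a+b' with integers."""
    left, sep, right = expr.partition("+")
    if sep and left.strip().isdigit() and right.strip().isdigit():
        return int(left) + int(right)
    return None


def run_row(model_tokens):
    """Replay a scripted token stream; return (token, mask) pairs, mask 0 = forced."""
    model_tokens = deque(model_tokens)
    forced = deque()
    in_python_block, expr_tokens, emitted = False, [], []
    while forced or model_tokens:
        # Forced tokens take priority over whatever the model would emit next
        is_forced = len(forced) > 0
        token = forced.popleft() if is_forced else model_tokens.popleft()
        emitted.append((token, 0 if is_forced else 1))
        if token == PYTHON_START:
            in_python_block, expr_tokens = True, []
        elif token == PYTHON_END and in_python_block:
            in_python_block = False
            if expr_tokens:
                result = toy_calculator("".join(expr_tokens))
                if result is not None:
                    forced.extend([OUTPUT_START, str(result), OUTPUT_END])
            expr_tokens = []
        elif in_python_block:
            expr_tokens.append(token)
    return emitted


script = ["2+2 is ", PYTHON_START, "2", "+", "2", PYTHON_END, "."]
for token, mask in run_row(script):
    print(mask, token)
```

The three output tokens appear with mask 0 directly after <|python_end|>, and the scripted "." resumes with mask 1 once the forced queue is empty.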
Supported Expressions
The calculator supports:
- Math expressions: 2+2, 3.14*10, 100/5
- String operations: "hello".count("l"), "world".count("o")
Safety features:
- Timeout after 3 seconds
- No access to builtins or dangerous operations
- Disallows power operator **
- Sanitizes input to prevent code injection
From nanochat/engine.py:47-80:
def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")
    # Check if it's a pure math expression
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)
    # Check if it's a string operation we support
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None
    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                          'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                          'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None
    # Only allow .count() method for now
    if '.count(' not in expr:
        return None
    return eval_with_timeout(expr)
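The excerpt calls eval_with_timeout, which is not shown on this page. As a rough illustration of the "no access to builtins" safety feature only, the sketch below evaluates an expression with an empty builtins table. `restricted_eval` is a hypothetical name, the timeout is deliberately not modeled, and nothing here should be read as the actual body of eval_with_timeout.

```python
def restricted_eval(expr):
    """Evaluate expr with an empty builtins table; return None on any error."""
    try:
        # An empty __builtins__ mapping means names such as open or
        # __import__ cannot be resolved inside the expression.
        return eval(expr, {"__builtins__": {}}, {})
    except Exception:
        return None


print(restricted_eval("2+2"))                  # 4
print(restricted_eval("100/5"))                # 20.0
print(restricted_eval('"hello".count("l")'))   # 2
print(restricted_eval('open("secrets.txt")'))  # None (open is not defined)
print(restricted_eval("1/0"))                  # None (ZeroDivisionError is swallowed)
```

An empty builtins table is not a complete sandbox on its own, since attribute chains on ordinary objects can still reach interpreter internals. That is why the character whitelist and the dangerous-pattern check in use_calculator run before anything is evaluated.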
Row State Tracking
When generating multiple samples in parallel, the engine maintains per-row state to track tool use independently.
From nanochat/engine.py:155-162:
class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.completed = False  # Whether this row has completed generation
Each sample maintains:
current_tokens: Full token history
forced_tokens: Queue of tokens to inject (from tool results)
in_python_block: Whether currently inside <|python_start|> … <|python_end|>
python_expr_tokens: Accumulated expression tokens
completed: Whether generation has ended for this sample
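A short sketch shows why the forced-token queue has to live on each row rather than on the engine as a whole. `ToyRowState` is a trimmed, hypothetical copy of the class above, and `build_column` is a made-up helper illustrating one plausible way to assemble a token column and its masks. The token ids are arbitrary.

```python
from collections import deque


class ToyRowState:
    """Trimmed stand-in for RowState: only the forced-token queue is kept."""

    def __init__(self):
        self.forced_tokens = deque()


def build_column(row_states, sampled_tokens):
    """Pick the next token per row: a queued forced token wins over the sampled one."""
    token_column, token_masks = [], []
    for state, sampled in zip(row_states, sampled_tokens):
        is_forced = len(state.forced_tokens) > 0
        token_masks.append(0 if is_forced else 1)
        token_column.append(state.forced_tokens.popleft() if is_forced else sampled)
    return token_column, token_masks


rows = [ToyRowState(), ToyRowState(), ToyRowState()]
rows[1].forced_tokens.extend([901, 902])  # only row 1 has a pending tool result

print(build_column(rows, [5, 6, 7]))     # ([5, 901, 7], [1, 0, 1])
print(build_column(rows, [8, 9, 10]))    # ([8, 902, 10], [1, 0, 1])
print(build_column(rows, [11, 12, 13]))  # ([11, 12, 13], [1, 1, 1])
```

Row 1 consumes its two queued tokens while rows 0 and 2 keep sampling, and all three rows are back to sampling on the third step.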
Testing
The engine module includes a built-in test that verifies correctness and benchmarks performance. Run it with:
python -m nanochat.engine
This compares the engine’s output against the model’s naive generation function and reports timing.
From nanochat/engine.py:302-357:
if __name__ == "__main__":
    """
    Quick inline test to make sure that the naive/slow model.generate function
    is equivalent to the faster Engine.generate function here.
    """
    # Load model
    model, tokenizer, meta = load_model("base", device, phase="eval")
    bos_token_id = tokenizer.get_bos_token_id()
    kwargs = dict(max_tokens=64, temperature=0.0)
    prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
    # Generate with reference implementation
    generated_tokens = []
    torch.cuda.synchronize()
    t0 = time.time()
    stream = model.generate(prompt_tokens, **kwargs)
    with autocast_ctx:
        for token in stream:
            generated_tokens.append(token)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Reference time: {t1 - t0:.2f}s")
    reference_ids = generated_tokens
    # Generate with Engine
    generated_tokens = []
    engine = Engine(model, tokenizer)
    stream = engine.generate(prompt_tokens, num_samples=1, **kwargs)
    torch.cuda.synchronize()
    t0 = time.time()
    with autocast_ctx:
        for token_column, token_masks in stream:
            token = token_column[0]
            generated_tokens.append(token)
    torch.cuda.synchronize()
    t1 = time.time()
    print(f"Engine time: {t1 - t0:.2f}s")
    # Compare
    print(f"Match: {reference_ids == generated_tokens}")
Complete Example: Multi-Sample Generation
import torch
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, autodetect_device_type
# Initialize
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("sft", device, phase="eval")
engine = Engine(model, tokenizer)
# Prepare prompt
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
tokens = [bos, user_start]
tokens.extend(tokenizer.encode("Tell me a joke"))
tokens.extend([user_end, assistant_start])
# Generate 4 different jokes in parallel
results, masks = engine.generate_batch(
    tokens,
    num_samples=4,
    max_tokens=200,
    temperature=1.0,
    top_k=50,
    seed=42
)
for i, (result_tokens, mask) in enumerate(zip(results, masks)):
    # Only decode the assistant's response (after assistant_start)
    response_start = len(tokens)
    response_tokens = result_tokens[response_start:]
    text = tokenizer.decode(response_tokens)
    print(f"\n=== Sample {i+1} ===")
    print(text)
This efficiently generates 4 diverse responses by processing the prompt once and then sampling 4 times in parallel.