Finding What to Fine-Tune with Ablation
I recently learned about ablation and found it compelling, especially for fine-tuning. Here's how I understand it, with practical examples.
Suppose we have a prompt like "a cat behind a chair" for an image generator. The model nails the cat and the chair but messes up the spatial relationship—the behind part.
Ablation's core idea: systematically disable parts of a model and see how outputs change. That reveals what each part does. It works for classifiers, image generators, anything. Whenever you need to understand a model's internals, you ablate.
Load a model in PyTorch and explore its structure:
```python
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
print(pipe.transformer)
```

This prints a nested table of contents—every named component. You might see something like:
```
FluxTransformer(
  (time_embed): Timesteps(...)
  (text_embed): Linear(...)
  (transformer_blocks): ModuleList(
    (0): FluxBlock(
      (attn): Attention(
        (to_q): Linear(3072, 3072)
        (to_k): Linear(3072, 3072)
        (to_v): Linear(3072, 3072)
        (to_out): Linear(3072, 3072)
      )
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(3072, 12288)
          (1): GELU()
          (2): Linear(12288, 3072)
        )
      )
    )
    (1): FluxBlock(...)
    (2): FluxBlock(...)
    ... maybe 20-40 of these ...
  )
)
```

Each named item—to_q, to_k, ff, the whole FluxBlock(0)—is a “part” you can ablate.
Say you want to test transformer block 15. You need a hook that makes the block a no-op. The catch: transformer blocks use residual connections. Each block's output includes both its own contribution and a pass-through of the input. If you zeroed the output, you'd wipe out the entire residual stream—all the information accumulated by blocks 0–14—not just block 15's contribution. Instead, return the input unchanged so block 15 becomes a pass-through:
```python
def skip_block_15(module, input, output):
    # `input` is the tuple of positional arguments the block received.
    # Return the block's input unchanged, overriding its real output.
    # (If a block returns a tuple, return a matching tuple of its inputs.)
    return input[0]

handle = pipe.transformer.transformer_blocks[15].register_forward_hook(skip_block_15)
```

The model runs end to end, but block 15 contributes nothing. Data flows in and the same data flows out, as if the block didn't exist.
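To make the residual point concrete, here's a minimal sketch of that structure (an illustration, not Flux's actual block code):

```python
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Illustration only: the block's output is its input plus its own contribution."""
    def __init__(self, dim=3072):
        super().__init__()
        self.fn = nn.Linear(dim, dim)  # stand-in for attention + feed-forward

    def forward(self, x):
        # Zeroing this return value would also erase x (the residual stream);
        # returning x unchanged removes only self.fn's contribution.
        return x + self.fn(x)
```

With the hook registered, generate a test set with block 15 ablated, then the same prompts as a baseline: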
```python
test_prompts = [
    "a cat standing behind a chair",
    "a ball underneath a table",
    "a person in front of a building",
    "a dog sitting on top of a car",
]

# Generate with block 15 ablated
for prompt in test_prompts:
    image = pipe(prompt).images[0]
    image.save(f"ablated_block15_{prompt[:20]}.png")

# Remove the hook, then generate the baseline for comparison
handle.remove()
for prompt in test_prompts:
    image = pipe(prompt).images[0]
    image.save(f"baseline_{prompt[:20]}.png")
```

With block 15 disabled, spatial relationships may break but colors stay fine. Or everything looks nearly identical (block 15 wasn't important). Or images turn to garbage (block 15 is critical for coherence).
If block 15 mattered, go deeper—ablate individual sub-layers. Unlike full blocks, sub-layers like to_q don't have their own residual connections, so zeroing the output is the right approach here. With all-zero queries, the attention can't be selective about what to attend to—softmax(0) gives uniform weights, so the output becomes a meaningless average rather than a focused combination.
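A quick standalone check of that claim:

```python
import torch

# With zeroed queries, every attention score is 0, so softmax returns
# uniform weights no matter what the keys contain.
q = torch.zeros(1, 4, 64)                      # zeroed queries: (batch, tokens, dim)
k = torch.randn(1, 4, 64)                      # arbitrary keys
scores = q @ k.transpose(-1, -2) / 64 ** 0.5   # all zeros
print(torch.softmax(scores, dim=-1))           # every row is [0.25, 0.25, 0.25, 0.25]
```

To run the actual ablation, hook to_q in block 15 and zero its output: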
```python
import torch

def kill_query(module, input, output):
    # Replace the query projection's output with zeros.
    return torch.zeros_like(output)

handle = pipe.transformer.transformer_blocks[15].attn.to_q.register_forward_hook(kill_query)
```

Build a spreadsheet:
| Component disabled | Spatial accuracy | Color accuracy | Overall quality |
|---|---|---|---|
| None (baseline) | 3/10 | 8/10 | 7/10 |
| Block 5 | 3/10 | 8/10 | 7/10 |
| Block 15 | 1/10 | 8/10 | 4/10 |
| Block 15 attn.to_q | 2/10 | 8/10 | 5/10 |
| Block 15 attn.to_k | 3/10 | 7/10 | 6/10 |
| Block 15 ff | 3/10 | 8/10 | 7/10 |
Block 15 clearly matters for spatial reasoning—skipping it entirely drops accuracy from 3 to 1. Within that block, to_q is the single biggest individual contributor (3→2), though the full degradation isn't explained by any one sub-layer alone, suggesting they interact. Still, this gives us a strong starting point for where to focus fine-tuning.
How many components are there to test? It varies widely. A 12-billion-parameter model like Flux might have 30–40 transformer blocks, each containing several sub-layers (query, key, value projections, output projection, feed-forward layers).
- Block level: 30–40 things to test. Each is a named module you can hook directly.
- Sub-layer level: a few hundred (every to_q, to_k, to_v, to_out, and feed-forward layer across all blocks).
- Attention head level: ~700–1,000. These require a different technique—multi-head attention is implemented as a single large matrix that gets reshaped into heads internally, so you can't hook individual heads as named modules. Instead, you zero specific slices of the output tensor (see the sketch after this list).
- Parameter level: billions—no one ablates at that scale.
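Here's a sketch of that slice-zeroing approach for a single head. The head layout is an assumption (3072 = 24 heads × 128 dims); check the model's config. Zeroing one head's slice of the value projection's output removes that head's contribution to the attention result.

```python
import torch

NUM_HEADS, HEAD_DIM = 24, 128   # assumed layout: 3072 = 24 x 128
HEAD = 7                        # which head to ablate

def zero_head(module, input, output):
    # Zero only this head's slice of the value projection's output.
    output = output.clone()
    output[..., HEAD * HEAD_DIM:(HEAD + 1) * HEAD_DIM] = 0
    return output

handle = pipe.transformer.transformer_blocks[15].attn.to_v.register_forward_hook(zero_head)
# ... generate test images, then handle.remove()
```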
In practice, start at the block level (30–40 experiments), find the 3–5 relevant ones, then drill into the sub-layers within just those blocks.
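A block-level sweep is just a loop around the hook from earlier, ablating one block at a time (scoring the resulting images is up to you):

```python
def skip_block(module, input, output):
    return input[0]

# Reuses the test_prompts list from earlier
for i, block in enumerate(pipe.transformer.transformer_blocks):
    handle = block.register_forward_hook(skip_block)
    for prompt in test_prompts:
        pipe(prompt).images[0].save(f"ablate_block{i:02d}_{prompt[:20]}.png")
    handle.remove()  # restore the block before testing the next one
```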
Were these specializations designed in? No. No one sat down and decided “Block 15 handles spatial reasoning, Block 8 handles color.” It doesn't work that way.
Architecturally, blocks are identical copies—same attention, same feed-forward wiring. Then you train on millions of images with text descriptions. Through gradient descent, weights inside each block adjust, and different blocks organically develop informal specializations. No roles are assigned; they emerge.
Some common patterns across many models:
- Early blocks handle low-level features (edges, textures, color).
- Middle blocks handle composition, object relationships, layout.
- Late blocks handle fine details and coherence.
But these are fuzzy tendencies, not clean divisions. Some capabilities distribute across many blocks. Spatial reasoning might be exactly that kind of distributed capability, which is why it's hard to fix.
The field dedicated to answering “what do the parts do?” is mechanistic interpretability.
One important caveat: ablation only tells you what's involved in a capability. It's purely subtractive—you remove a piece and observe what breaks. That proves the piece participates, but it doesn't prove that fine-tuning that piece alone will fix the capability. Spatial reasoning might require coordinated changes across multiple blocks, and block 15 might be just one participant in a distributed computation.
Think of it like removing a fuse from a car and finding that the headlights stop working. That tells you the fuse is in the headlight circuit—but if the headlights are dim, replacing that one fuse might not fix the problem. You might also need a new bulb.
In practice, ablation narrows your search space from "all 12 billion parameters" to "a few specific blocks." That's enormously valuable. But the fine-tuning step that follows is still somewhat experimental—you're making an informed bet, not a guaranteed fix. The example below is simplified to show the workflow.
Suppose ablation pinned spatial reasoning to block 15's attention. Here's how to use that for targeted fine-tuning.
Use Blender's API to generate thousands of scenes with precise spatial relationships:
```python
scenes = [
    {"prompt": "a cat standing behind a chair",
     "image": render_scene(cat_pos=(0, 0, -2), chair_pos=(0, 0, 0), camera="front")},
    {"prompt": "a ball underneath a table",
     "image": render_scene(ball_pos=(0, -1, 0), table_pos=(0, 0, 0), camera="front")},
    # ... thousands more with varied objects, positions, angles
]
```
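The render_scene helper isn't defined anywhere above; here's a rough sketch of what it could look like with Blender's bpy module. Everything in it is illustrative: primitive cubes stand in for real assets, only a "front" camera is handled, and it returns the rendered file's path rather than an image object.

```python
import math
import bpy

def render_scene(output_path="scene.png", camera="front", **object_positions):
    bpy.ops.wm.read_factory_settings(use_empty=True)   # start from an empty scene

    # Place a stand-in cube at each requested position, e.g. cat_pos=(0, 0, -2)
    for name, pos in object_positions.items():
        bpy.ops.mesh.primitive_cube_add(location=pos)
        bpy.context.object.name = name

    # Camera roughly aimed at the origin, plus a sun light
    bpy.ops.object.camera_add(location=(0, -8, 2), rotation=(math.radians(80), 0, 0))
    bpy.context.scene.camera = bpy.context.object
    bpy.ops.object.light_add(type='SUN', location=(0, 0, 5))

    bpy.context.scene.render.filepath = output_path
    bpy.ops.render.render(write_still=True)
    return output_path
```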
Freeze the whole model, then unfreeze only the identified component:

```python
# Freeze everything
for param in pipe.transformer.parameters():
    param.requires_grad = False

# Unfreeze the attention layers ablation pointed to: block 15, plus neighboring block 14
for param in pipe.transformer.transformer_blocks[15].attn.parameters():
    param.requires_grad = True
for param in pipe.transformer.transformer_blocks[14].attn.parameters():
    param.requires_grad = True
```

Only these weights update during training. Everything else remains untouched. This is ablation's payoff—otherwise, you'd retrain all 12 billion parameters, which is costly and risks degrading other capabilities.
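A quick sanity check that the freeze did what we intended:

```python
# Count trainable vs. total parameters after freezing
trainable = sum(p.numel() for p in pipe.transformer.parameters() if p.requires_grad)
total = sum(p.numel() for p in pipe.transformer.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.3f}%)")
```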
Or, rather than updating those weights directly, use LoRA on the same targets:

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "transformer_blocks.14.attn.to_q",
        "transformer_blocks.14.attn.to_v",
        "transformer_blocks.15.attn.to_q",
        "transformer_blocks.15.attn.to_v",
    ],
    lora_dropout=0.05,
)
model = get_peft_model(pipe.transformer, lora_config)
model.print_trainable_parameters()
# Output: "trainable params: 2,359,296 || all params: 12,000,000,000 || 0.02%"
```

Instead of directly modifying weights, we attach small trainable adapters to just the crucial blocks. We're training only 0.02% of the model.
Standard diffusion training: add noise, predict it, measure error, and update the unfrozen weights:
```python
import torch
import torch.nn.functional as F
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)

for epoch in range(10):
    for batch in dataloader:  # batches built from the rendered scenes above
        images = batch["images"]
        prompts = batch["prompts"]

        noise = torch.randn_like(images)
        timestep = torch.randint(0, 1000, (images.shape[0],))
        noisy_images = add_noise(images, noise, timestep)  # placeholder for the scheduler's noising step

        predicted_noise = model(noisy_images, timestep, prompts)  # call signature simplified
        loss = F.mse_loss(predicted_noise, noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch}, Loss: {loss.item()}")
```

Then compare the original and fine-tuned pipelines on spatial prompts:

```python
test_prompts = [
    "a cat standing behind a chair",
    "a dog in front of a fireplace",
    "a cup underneath a shelf",
]

# pipe_original: an untouched copy of the pipeline; pipe_finetuned: the pipeline
# whose transformer carries the trained adapters
for p in test_prompts:
    original = pipe_original(p).images[0]
    original.save(f"original_{p[:20]}.png")

for p in test_prompts:
    finetuned = pipe_finetuned(p).images[0]
    finetuned.save(f"finetuned_{p[:20]}.png")
```

Did spatial accuracy improve without hurting anything else? If so, we succeeded.
Without ablation, we'd blindly apply LoRA to all blocks—more compute, more risk, and less insight. Knowing the capability is concentrated in block 15's attention (plus its neighbor) lets us be precise and efficient.