In this post I’ll walk through my current experiment on structured pruning for large language models, inspired by WANDA but extended to an adaptive version in which the pruning ratio for each layer is searched by block-wise Bayesian optimization. The goal is to find pruning ratios that preserve performance, mainly Python coding, while reducing model size. For an easy proof of concept and my personal use, I pick qwen3-0.6B as the baseline model, a versatile instruction-tuned CoT model. It’s a challenge to prune such a small language model without significantly harming its ability, but I believe it’s doable. For readers who are unfamiliar with pruning or curious about other LLM compression techniques, please refer to this review.
Step 1: Row Pruning of FFN Layers
I start by pruning rows of the FFN weight matrices (i.e., MLP neurons). The main reason I leave the attention matrices untouched is that qwen3 normalizes the query and key heads before applying rotary position embeddings in its attention module:
class Qwen3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(...):
        ...
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
        ...

    def forward(...):
        ...
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        ...
where Qwen3RMSNorm is essentially Root Mean Square Normalization (RMSNorm). Here are the obstacles that stopped me from pruning attention heads or weight-matrix rows in a simple way:
- Normalization parameter mismatch: Qwen3 normalizes queries and keys per head, along the head dimension. Pruning heads or head dimensions leaves these normalization weights misaligned.
- RoPE dimension dependencies: The rotary position embeddings expect consistent head dimensions, which can break after pruning.
- Statistical disruption: The per-head normalization statistics change when heads are removed, potentially destabilizing the model.
To avoid harming stability, the first layer is left untouched. Row pruning makes sense here because removing a row corresponds to removing a neuron, which cleanly maps to reducing dimensionality in the next layer. Instead of physically removing neurons during the search, I mask out rows with small WANDA scores. This allows fast experimentation without modifying the model architecture. Later, once I have an optimized pruning schedule, I can hard-remove the neurons and adjust downstream weight matrices accordingly.
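As a rough illustration, here is a minimal sketch of WANDA-style row masking for the Qwen3 MLP (`gate_proj`/`up_proj`/`down_proj`). The per-channel activation norms `act_norm` are assumed to come from the calibration pass described in the next step, and summing the element-wise WANDA scores into a per-row importance is my simplification, not necessarily the exact aggregation I use:

```python
import torch

def wanda_row_scores(weight: torch.Tensor, act_norm: torch.Tensor) -> torch.Tensor:
    # weight: (out_features, in_features); act_norm: (in_features,) per-channel L2 norm
    # of the calibration inputs. WANDA scores each weight as |W_ij| * ||X_j||; summing
    # over the input dimension gives a (simplified) per-row importance.
    return (weight.abs() * act_norm.unsqueeze(0)).sum(dim=1)

def mask_mlp_rows(mlp, ratio: float, act_norm: torch.Tensor) -> None:
    # Zero the lowest-scoring rows of gate_proj/up_proj (i.e. MLP neurons) in place,
    # leaving shapes untouched so the architecture stays unchanged during the search.
    scores = wanda_row_scores(mlp.gate_proj.weight.data, act_norm)
    n_prune = int(ratio * scores.numel())
    idx = torch.argsort(scores)[:n_prune]
    mlp.gate_proj.weight.data[idx, :] = 0.0
    mlp.up_proj.weight.data[idx, :] = 0.0
    # down_proj consumes these neurons as input columns, so zero them there as well
    mlp.down_proj.weight.data[:, idx] = 0.0
```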
Step 2: Calibration Dataset Link to heading
For calibration, I sample 128 examples each from MBPP (Python code generation benchmark) and WizardCoder training sets. This gives me 256 samples total for WANDA pruning. The calibration data is only used to measure neuron importance — not for retraining.
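For concreteness, here is a sketch of how such a calibration pass could look with Hugging Face `datasets`/`transformers`. The MBPP field names (`text`, `code`) and the `Qwen/Qwen3-0.6B` model id are real; the WizardCoder dataset id is omitted, and the hook-based norm collection is my own simplified version of a WANDA calibration pass:

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
model.eval()

# 128 MBPP prompts; the 128 WizardCoder samples are drawn the same way (dataset id omitted here)
mbpp = load_dataset("mbpp", split="train").shuffle(seed=0).select(range(128))
calib_texts = [ex["text"] + "\n" + ex["code"] for ex in mbpp]

# Accumulate the squared L2 norm of each MLP's input over the calibration tokens;
# the square root gives the per-channel ||X_j|| used by the row scores above.
sq_norms = {}
def make_hook(name):
    def hook(module, inputs, output):
        x = inputs[0].detach().float().reshape(-1, inputs[0].shape[-1])
        sq_norms[name] = sq_norms.get(name, 0.0) + (x ** 2).sum(dim=0)
    return hook

handles = [layer.mlp.gate_proj.register_forward_hook(make_hook(f"layer_{i}"))
           for i, layer in enumerate(model.model.layers)]
with torch.no_grad():
    for text in calib_texts:
        model(**tok(text, return_tensors="pt", truncation=True, max_length=512))
for h in handles:
    h.remove()
act_norms = {name: s.sqrt() for name, s in sq_norms.items()}  # one vector per layer
```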
Step 3: Statistical Layer-Wise Comparison
To preserve the statistical traits of layer outputs after pruning, instead of evaluating KL divergence or cosine distance at the logits level, I designed a statistical comparison function between the original and pruned models’ layer activations. For each layer, I compute:
- Mean activation L2 shift (absolute + relative)
- RMS activation shift (captures WANDA’s core idea about activation magnitude)
- Cosine similarity between mean activations (directional change)
These metrics are combined into a composite score:
composite_score = 0.6 * rms_score + 0.25 * mean_score + 0.15 * direction_score
This gives a per-layer sensitivity signal while only requiring a small layer-statistics file (<0.5 MB) for the original model.
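A small sketch of how these per-layer shift metrics could be combined. The exact normalization of each sub-score is my assumption; only the 0.6/0.25/0.15 weighting comes from the formula above:

```python
import torch
import torch.nn.functional as F

def composite_layer_score(orig_mean, orig_rms, pruned_mean, pruned_rms):
    # orig_*/pruned_*: per-channel statistics of one layer's output for the
    # original and the masked model (lower score = smaller shift from the original).
    # Relative L2 shift of the mean activation
    mean_score = (torch.norm(pruned_mean - orig_mean) / (torch.norm(orig_mean) + 1e-8)).item()
    # RMS activation shift (WANDA's activation-magnitude intuition)
    rms_score = (torch.abs(pruned_rms - orig_rms) / (orig_rms + 1e-8)).mean().item()
    # Directional change of the mean activation (1 - cosine similarity)
    direction_score = 1.0 - F.cosine_similarity(orig_mean, pruned_mean, dim=0).item()
    return 0.6 * rms_score + 0.25 * mean_score + 0.15 * direction_score
```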
Step 4: Bayesian Optimization with Optuna
Why Bayesian Optimization? Choosing how much to prune per layer is not obvious:
- Early layers are often more sensitive,
- Later layers may tolerate more pruning,
- Interactions between blocks mean pruning decisions in one layer can amplify or dampen effects downstream.

Naive approaches like uniform pruning or manually tuning ratios don’t capture this. Instead, I use Bayesian Optimization (BO) to automatically discover good pruning schedules. BO treats pruning as a black-box optimization problem:
- Input: pruning ratios for a group of layers (the decision variables).
- Output: a score combining model performance (Pass@1) and structural stability (layer statistics).

The BO loop uses a probabilistic model to suggest pruning ratios likely to improve the tradeoff between compression and accuracy, requiring far fewer trials than random or grid search. More info about BO and the TPE estimator can be found in my thesis.
This way, the pruning schedule is learned adaptively rather than hand-crafted, a key step toward globally-aware structured pruning. To make the search easier to conduct, I shrink the search-space dimension and formulate pruning as a block-wise optimization problem. Each block contains 4 layers (7 blocks in total for the 28-layer qwen3-0.6B), and in each trial BO assigns a pruning ratio to every layer in the current block, searching for the best configuration block by block. The single-objective function in Optuna is constructed as:
obj = cd - pruned_params - pr
where:
- cd: composite score (sensitivity) for one block plus the final layer
- pr: Pass@1 on one MBPP validation sample
- pruned_params: pruning ratio in the current block

I also add some coefficients to balance the effect of these terms.
For each block, I run 40 iterations using the default TPE sampler in Optuna to search pruning ratios for 4 layers at once.
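A minimal Optuna sketch of one block’s search, assuming the study minimizes the objective above; the 0.0–0.5 ratio range and the `evaluate_block` helper (which would apply the masks and return the composite shift and Pass@1) are placeholders, not my actual implementation:

```python
import random
import optuna

def evaluate_block(ratios):
    # Placeholder for the real routine: apply the row masks in `ratios`, then
    # measure the composite shift `cd` and Pass@1 `pr` as described above.
    # Random values stand in here so the sketch runs end to end.
    return random.random(), random.random()

def objective(trial, block_layers=(1, 2, 3, 4)):
    ratios = {i: trial.suggest_float(f"ratio_layer_{i}", 0.0, 0.5) for i in block_layers}
    cd, pr = evaluate_block(ratios)
    pruned_params = sum(ratios.values()) / len(ratios)
    return cd - pruned_params - pr  # lower: small shift, high sparsity, high Pass@1

study = optuna.create_study(direction="minimize")  # default sampler is TPE
study.optimize(objective, n_trials=40)
print(study.best_params)
```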
Step 5: Validation & Comparison
After searching, I test the best pruning ratio config on the full validation set:
- Perplexity (PPL)
- Pass@1 on MBPP
Note that PPL is a general evaluation metric and offers little guidance when pruning for a coding dataset. I use Pass@1 both during BO (a single sample from the MBPP validation set) and for the final evaluation (25 MBPP validation samples that the baseline passes). Then I compare against:
- Baseline (unpruned model)
- Uniform pruning (same average pruning ratio applied across layers)
This lets me quantify the benefit of adaptive, block-wise Bayesian pruning compared to simple heuristics. In addition, I test the output of all three models on a classic quick sort task for a more intuitive, direct comparison.
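Before the qualitative demo, here is a simplified sketch of the exec-based Pass@1 check on MBPP-style samples; a real harness should sandbox and time-limit the execution:

```python
def passes_tests(generated_code, test_list):
    # Execute the candidate and its MBPP assertions in a scratch namespace.
    # Simplified: no sandboxing or timeouts, unlike a production harness.
    env = {}
    try:
        exec(generated_code, env)
        for test in test_list:  # e.g. "assert quick_sort([3, 1, 2]) == [1, 2, 3]"
            exec(test, env)
        return True
    except Exception:
        return False

def pass_at_1(samples):
    # samples: iterable of (generated_code, test_list) pairs, one completion per task
    results = [passes_tests(code, tests) for code, tests in samples]
    return sum(results) / len(results)
```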
prompt: Write a quick sort in python.
- Baseline qwen3-0.6B
  - PPL: 3.9956388652324675 on 20 WizardCoder validation samples
  - Pass@1: passes 25/90 MBPP validation samples
generated content ✅ : Sure! Here’s a simple quick sort implementation in Python:
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
right = [x for x in arr if x > pivot]
return quick_sort(left) + [pivot] + quick_sort(right)
# Example usage:
arr = [3, 2, 1, 5, 4, 6]
sorted_arr = quick_sort(arr)
print(sorted_arr)
- BO-searched model
  - PPL: 4.213513535261154
  - Pass@1 on 25 MBPP validation samples: 0.84
  - Pruned ratio (summed over layers): 3.60
generated content ✅ : Sure! Here’s a quick sort implementation in Python, which is a classic sorting algorithm:
def quick_sort(arr):
if len(arr) <= 1:
return arr
mid = arr[len(arr) // 2]
left = [num for num in arr if num < mid]
right = [num for num in arr if num > mid]
return quick_sort(left) + [mid] + quick_sort(right)
# Example usage:
arr = [6, 5, 3, 2, 1]
sorted_arr = quick_sort(arr)
print(sorted_arr)
- Uniform (average) pruning
  - PPL: 4.302274990081787
  - Pass@1 on 25 MBPP validation samples: 0.76
  - Pruned ratio (summed over layers): 3.51 (0.13 per layer except the first layer)
generated content ❌ : Sure! Here’s a quick sort implementation in Python:
def quicksort(arr):
if len(arr) <= 1:
return arr
pivot = arr[0]
left = [x for x in arr if x != pivot]
right = [x for x in arr if x == pivot]
return quicksort(left) + [pivot] + quicksort(right)
# Example usage
arr = [3, 2, 1, 4, 5, 6, 7, 8, 9]
sorted_arr = quicksort(arr)
print(sorted_arr)
✨ Key Takeaways
- Row pruning of FFN layers works well as a structured pruning strategy.
- Instead of expensive KL divergence, layer-wise statistical metrics provide a cheaper, interpretable signal.
- Bayesian optimization can adaptively allocate pruning budgets across layers, outperforming uniform pruning.
- Block-wise search balances granularity and search complexity.
This framework is a step toward globally-aware structured pruning where pruning ratios are learned, not hand-tuned.
Next up: I plan to test multi-objective Bayesian optimization (KL divergence vs. sparsity), investigate better evaluation methods for coding LLMs, refine the prompts, and enhance context engineering with ASTs.