Model Distillation
Compress a large teacher model into a compact student model that retains most of the teacher's accuracy at a fraction of the inference cost.
Overview
Knowledge distillation trains a smaller student network to mimic the softened output distribution of a larger teacher. Instead of training on hard labels alone, the student learns from the teacher's logits, capturing inter-class relationships that one-hot targets discard. The result is a model that runs on edge hardware, mobile devices, or CPU-only inference while preserving most of the teacher's predictive power.
Temperature Scaling
The teacher produces logits z_t and the student produces logits z_s. Dividing the logits by a temperature T > 1 before the softmax flattens the resulting distribution, exposing so-called dark knowledge: which classes the teacher considers plausible alternatives to its top prediction. The student minimizes a weighted sum of the distillation loss (KL divergence between the softened teacher and student outputs) and the standard cross-entropy loss against ground-truth labels.
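To make the effect of temperature concrete, here is a minimal sketch in plain Python (no framework assumed). The function name `softmax_with_temperature` and the example logits are illustrative choices, not something the recipe prescribes.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    # Subtracting the max keeps exp() from overflowing and leaves the result unchanged.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up teacher logits for three classes, for illustration only.
teacher_logits = [6.0, 2.0, -1.0]
sharp = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=4.0)
```

With T = 4 the top class gives up probability mass to the other two, so the student sees how the teacher ranks the non-argmax classes instead of a near one-hot target.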
Loss Function
L = α · T² · KL( softmax(z_t / T) || softmax(z_s / T) ) + (1 - α) · CE( softmax(z_s), y_true )
Gradients of the soft-target term shrink roughly as 1/T², so the T² factor keeps that term's contribution comparable to the hard-label term as T changes. Typical values: T ∈ [2, 8], α ∈ [0.7, 0.9].
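The loss above can be sketched for a single example in plain Python. This is an illustration under stated assumptions, not a reference implementation: the helper names `log_softmax` and `distillation_loss` are hypothetical, the defaults T = 4 and α = 0.8 are simply picked from the typical ranges, and a real training loop would use a framework's batched, differentiable equivalents.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Numerically stable log(softmax(logits / temperature))."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]

def distillation_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.8):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE(student, label)."""
    log_p_t = log_softmax(teacher_logits, T)
    log_p_s = log_softmax(student_logits, T)
    # KL(p_t || p_s) = sum_i p_t[i] * (log p_t[i] - log p_s[i]), both softened by T.
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    # Hard-label cross-entropy uses the unsoftened student distribution (T = 1).
    ce = -log_softmax(student_logits)[label]
    return alpha * T * T * kl + (1.0 - alpha) * ce
```

When the student's logits equal the teacher's, the KL term is zero and only the cross-entropy term remains; setting `alpha=0.0` reduces the function to ordinary hard-label training.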
Architecture Choices
- Same vocabulary: Student shares the teacher's tokenizer and output head dimensions.
- Layer reduction: Drop every other transformer block; initialize student layers from teacher weights for faster convergence.
- Hidden-dim shrink: Reduce embedding dimension and FFN intermediate size proportionally.
- Attention head pruning: Fewer heads per layer; merge or discard based on head-importance scoring.
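The layer-reduction option can be sketched as an index selection followed by a weight copy. The names `select_teacher_layers` and `init_student_from_teacher` are hypothetical, and each block is assumed to be a copyable Python object (for example a dict of weights). Whether to keep the even- or odd-numbered blocks is a choice the text does not fix, so it is exposed here as an `offset` parameter.

```python
import copy

def select_teacher_layers(num_teacher_layers, keep_every=2, offset=0):
    """Indices of the teacher blocks to keep; keep_every=2 drops every other block."""
    if keep_every < 1:
        raise ValueError("keep_every must be >= 1")
    return list(range(offset, num_teacher_layers, keep_every))

def init_student_from_teacher(teacher_blocks, keep_every=2, offset=0):
    """Build the student's block list from deep copies of the kept teacher blocks."""
    kept = select_teacher_layers(len(teacher_blocks), keep_every, offset)
    # Deep copies ensure that training the student never mutates the frozen teacher.
    return [copy.deepcopy(teacher_blocks[i]) for i in kept]
```

For a 12-block teacher, `select_teacher_layers(12)` keeps blocks 0, 2, 4, 6, 8 and 10, giving a 6-block student. This direct copy only works when the student keeps the teacher's hidden dimension; combining it with a hidden-dim shrink needs a projection or a fresh initialization.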
Training Recipe
- Freeze the teacher and run forward passes over the full dataset to cache softened logits.
- Initialize the student with a subset of teacher layers where dimensions match.
- Train with AdamW, a cosine learning-rate schedule, and linear warmup over the first 5% of steps.
- Monitor validation perplexity; stop when the student's validation loss plateaus within 2% of the teacher's.
- Fine-tune the final epoch with α = 0 (hard labels only) to sharpen predictions.
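The warmup-plus-cosine schedule from the recipe can be written as a pure function of the step index. This is a sketch under assumptions: the function name `lr_at_step`, the `min_lr` floor, and the choice to reach the peak rate exactly at the end of warmup are illustrative, and deep learning frameworks ship their own scheduler classes that would normally be used instead.

```python
import math

def lr_at_step(step, total_steps, peak_lr, warmup_frac=0.05, min_lr=0.0):
    """Linear warmup over the first warmup_frac of steps, then cosine decay to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The returned value would be assigned to the optimizer's learning rate once per step. With `total_steps=1000`, warmup covers steps 0 through 49 and the cosine phase runs from step 50 to the end.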
Expected Results
A well-tuned student with 30–40% of the teacher's parameters typically retains 95–98% of the teacher's accuracy on benchmarks. Inference latency drops roughly in proportion to parameter count, though the exact speedup depends on hardware and batch size, and the memory footprint shrinks enough for on-device deployment. The trade-off surface is smooth: smaller students lose fidelity gradually, so you can pick the smallest model that meets your accuracy threshold.
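Picking the smallest model that meets an accuracy threshold reduces to a filter and a minimum over your own measurements. The helper below is a hypothetical sketch; the tuple layout `(name, num_params, accuracy)` is an assumption made for illustration.

```python
def smallest_student_meeting_threshold(candidates, min_accuracy):
    """Return the (name, num_params, accuracy) tuple with the fewest parameters
    whose accuracy is at least min_accuracy, or None if no candidate qualifies."""
    eligible = [c for c in candidates if c[2] >= min_accuracy]
    if not eligible:
        return None
    # Among the students that clear the bar, prefer the one with the fewest parameters.
    return min(eligible, key=lambda c: c[1])
```

The `candidates` list should hold accuracies measured on a held-out validation set for each trained student size. A `None` result means the threshold is out of reach for every size tried so far.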