Recipe

Model Distillation

Compress a large teacher model into a compact student model that retains most of the teacher's accuracy at a fraction of the inference cost.

Overview

Knowledge distillation trains a smaller student network to mimic the softened output distribution of a larger teacher. Instead of training on hard labels alone, the student learns from the teacher's logits, capturing inter-class relationships that one-hot targets discard. The result is a model that runs on edge hardware, mobile devices, or CPU-only inference while preserving most of the teacher's predictive power.

Temperature Scaling

The teacher produces logits z_t and the student produces logits z_s. A temperature parameter T > 1 softens the softmax, exposing the "dark knowledge" in the teacher's relative scores for the classes it considers plausible alternatives. The student minimizes a weighted sum of the distillation loss (KL divergence between the softened teacher and student outputs) and the standard cross-entropy loss against the ground-truth labels.
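To see the softening effect concretely, here is a minimal pure-Python sketch of a temperature-scaled softmax. The function name `softmax_with_temperature` and the three-class logits are illustrative assumptions, not part of any particular framework.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softmax over logits / T; T > 1 flattens the distribution."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [6.0, 2.0, 1.0]  # hypothetical logits for a 3-class example

print([round(p, 3) for p in softmax_with_temperature(teacher_logits, 1.0)])
# [0.976, 0.018, 0.007]
print([round(p, 3) for p in softmax_with_temperature(teacher_logits, 4.0)])
# [0.604, 0.222, 0.173]
```

At T = 1 the runner-up classes are nearly invisible; at T = 4 they carry enough probability mass for the student to learn how the teacher ranks them.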

Loss Function

L = α · T² · KL( softmax(z_t / T) || softmax(z_s / T) )
  + (1 - α) · CE( softmax(z_s), y_true )

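The T² factor compensates for the gradient scaling introduced by temperature: gradients from the soft targets shrink roughly as 1/T², so multiplying by T² keeps the two loss terms on a comparable scale as T changes. Typical values: T ∈ [2, 8], α ∈ [0.7, 0.9].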
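The loss above can be written out for a single example as a minimal pure-Python sketch. It assumes `label` is the integer index of the true class; the helper names (`softmax`, `kl_divergence`, `distillation_loss`) are hypothetical, and the defaults `T=4.0`, `alpha=0.8` are simply picked from the typical ranges quoted above.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax; T=1.0 gives the ordinary softmax."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(teacher_logits, student_logits, label, T=4.0, alpha=0.8):
    """Single-example L = alpha * T^2 * KL(p_t || p_s) + (1 - alpha) * CE."""
    p_t = softmax(teacher_logits, T)  # softened teacher distribution
    p_s = softmax(student_logits, T)  # softened student distribution
    soft = (T ** 2) * kl_divergence(p_t, p_s)
    hard = -math.log(softmax(student_logits)[label])  # CE against the true class at T = 1
    return alpha * soft + (1 - alpha) * hard

teacher = [6.0, 2.0, 1.0]  # hypothetical logits
student = [4.0, 2.5, 0.5]
print(distillation_loss(teacher, student, label=0))
```

Passing identical teacher and student logits with `alpha=1.0` gives a loss of zero, which is a quick sanity check. In a framework such as PyTorch the same computation is normally batched; note that `torch.nn.functional.kl_div` takes the student's log-probabilities as its first argument and the teacher's probabilities as its second, and `reduction="batchmean"` matches the mathematical definition of KL.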

Architecture Choices

  • Same vocabulary: Student shares the teacher's tokenizer and output head dimensions.
  • Layer reduction: Drop every other transformer block; initialize student layers from teacher weights for faster convergence.
  • Hidden-dim shrink: Reduce embedding dimension and FFN intermediate size proportionally.
  • Attention head pruning: Fewer heads per layer; merge or discard heads based on head-importance scores.
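The layer-reduction and head-pruning choices reduce to index selection. The sketch below assumes each teacher block's weights are available as one list entry (a hypothetical plain-Python stand-in for a real model's modules) and that per-head importance scores have already been computed by whatever head-importance method you use; all function names are illustrative.

```python
import copy

def every_other_layer(num_teacher_layers):
    """Teacher block indices kept when dropping every other block: 0, 2, 4, ..."""
    return list(range(0, num_teacher_layers, 2))

def init_student_from_teacher(teacher_layers, kept_indices):
    """Copy the selected teacher blocks into the student, preserving their order.

    teacher_layers is a hypothetical list of per-block weight containers;
    deepcopy keeps later student updates from mutating the frozen teacher.
    """
    return [copy.deepcopy(teacher_layers[i]) for i in kept_indices]

def heads_to_keep(importance, num_keep):
    """Indices of the num_keep highest-scoring heads, returned in original order."""
    ranked = sorted(range(len(importance)), key=lambda h: importance[h], reverse=True)
    return sorted(ranked[:num_keep])

# Usage: a 12-block teacher becomes a 6-block student; keep 2 of 4 heads.
print(every_other_layer(12))                   # [0, 2, 4, 6, 8, 10]
print(heads_to_keep([0.9, 0.1, 0.5, 0.7], 2))  # [0, 3]
```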

Training Recipe

  1. Freeze the teacher and run forward passes over the full dataset to cache softened logits.
  2. Initialize the student with a subset of teacher layers where dimensions match.
  3. Train with AdamW and a cosine learning-rate schedule, using linear warmup over the first 5% of steps.
  4. Monitor validation perplexity; stop when the student's perplexity plateaus within 2% of the teacher's.
  5. Fine-tune for a final epoch with α = 0 (hard labels only) to sharpen predictions.
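Steps 3 and 4 can be sketched as two small pure-Python functions: a learning-rate schedule with linear warmup over the first 5% of steps followed by cosine decay, and a stopping check on validation perplexity. The `patience` and `min_delta` parameters, and the definition of "plateau" used here (no meaningful improvement over the last few evaluations), are assumptions for illustration; the AdamW optimizer itself is not shown.

```python
import math

def lr_at_step(step, total_steps, peak_lr, warmup_frac=0.05, min_lr=0.0):
    """Linear warmup over the first warmup_frac of steps, then cosine decay to min_lr."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

def should_stop(student_ppl_history, teacher_ppl, gap=0.02, patience=3, min_delta=1e-3):
    """Stop once student perplexity is within `gap` of the teacher's and has plateaued.

    Plateau (an assumed definition): the best of the last `patience` evaluations
    improves on the earlier best by less than a relative `min_delta`.
    """
    if len(student_ppl_history) <= patience:
        return False
    within_gap = student_ppl_history[-1] <= teacher_ppl * (1 + gap)
    recent_best = min(student_ppl_history[-patience:])
    earlier_best = min(student_ppl_history[:-patience])
    plateaued = earlier_best - recent_best < min_delta * earlier_best
    return within_gap and plateaued
```

In PyTorch, `lr_at_step(step, total_steps, 1.0)` could serve as the multiplicative factor passed to `torch.optim.lr_scheduler.LambdaLR` wrapped around a `torch.optim.AdamW` optimizer, with the scheduler stepped once per training step.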

Expected Results

A well-tuned student with 30–40% of the teacher's parameters typically retains 95–98% of the teacher's accuracy on benchmarks. Inference latency drops roughly in proportion to parameter count, and the memory footprint shrinks enough for on-device deployment. The trade-off surface is smooth: smaller students lose fidelity gradually, so you can pick the smallest model that meets your accuracy threshold.
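Choosing the smallest acceptable student is a simple filter over evaluated candidates. The sketch below assumes you have measured each candidate's parameter count and accuracy yourself; the candidate names and numbers are placeholders, not results, and the 0.95 retention threshold is just the low end of the range quoted above.

```python
def smallest_acceptable(candidates, teacher_acc, min_retention=0.95):
    """Return the smallest candidate retaining at least min_retention of teacher accuracy.

    Each candidate is a dict with "name", "params", and "accuracy" keys
    (a hypothetical format). Returns None if no candidate qualifies.
    """
    threshold = min_retention * teacher_acc
    acceptable = [c for c in candidates if c["accuracy"] >= threshold]
    if not acceptable:
        return None
    return min(acceptable, key=lambda c: c["params"])

# Placeholder numbers for illustration only, not measurements.
candidates = [
    {"name": "student-small", "params": 30_000_000, "accuracy": 0.80},
    {"name": "student-medium", "params": 45_000_000, "accuracy": 0.86},
    {"name": "student-large", "params": 60_000_000, "accuracy": 0.88},
]
print(smallest_acceptable(candidates, teacher_acc=0.90)["name"])  # student-medium
```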