Problem 1: Multi-Head Attention - Multi-Digit Addition
Implement scaled dot-product attention and multi-head attention from scratch. Train on a multi-digit addition task and analyze what different attention heads learn.
Part A: Dataset and Data Loading
You will work with a multi-digit addition dataset where the model must add two 3-digit numbers with carry propagation.
Generate the dataset:
cd problem1
python generate_data.py --seed 641 --num-digits 3
The dataset contains:
- Input: two 3-digit numbers separated by a ‘+’ token (e.g., [3, 4, 7, +, 1, 5, 9] for 347 + 159)
- Output: the sum, zero-padded to 4 digits (e.g., [0, 5, 0, 6] for 506)
- Training samples: 10,000
- Validation samples: 2,000
- Test samples: 2,000
Example training samples:
{"input": [5, 0, 7, 10, 1, 5, 9], "target": [0, 6, 6, 6]}
{"input": [1, 5, 5, 10, 4, 9, 1], "target": [0, 6, 4, 6]}
{"input": [3, 9, 4, 10, 3, 1, 6], "target": [0, 7, 1, 0]}
Here token 10 represents the ‘+’ operator (507 + 159 = 666, 155 + 491 = 646, 394 + 316 = 710).
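The encoding can be reproduced in a few lines of Python. This sketch is not the provided generate_data.py: the function name encode_pair is hypothetical, and it assumes operands are zero-padded to num_digits digits (none of the examples above shows an operand with a leading zero).

```python
PLUS = 10  # token id for '+', as stated in the dataset description


def encode_pair(a: int, b: int, num_digits: int = 3) -> dict:
    """Encode a + b in the token format shown above (illustrative sketch only)."""
    if not (0 <= a < 10**num_digits and 0 <= b < 10**num_digits):
        raise ValueError("operands must have at most num_digits digits")
    # Operands are written most-significant digit first, zero-padded to num_digits.
    digits_a = [int(c) for c in str(a).zfill(num_digits)]
    digits_b = [int(c) for c in str(b).zfill(num_digits)]
    # The sum may need one extra digit for the final carry, so pad to num_digits + 1.
    target = [int(c) for c in str(a + b).zfill(num_digits + 1)]
    return {"input": digits_a + [PLUS] + digits_b, "target": target}


print(encode_pair(507, 159))
# {'input': [5, 0, 7, 10, 1, 5, 9], 'target': [0, 6, 6, 6]}
```

The extra output digit is what makes carries visible to the model: 999 + 999 encodes to the target [1, 9, 9, 8], where the leading 1 exists only because of a carry out of the hundreds place.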
The starter code provides dataset.py, which contains the data loader.
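The starter dataset.py is not reproduced here. As a framework-free illustration of what a loader for this data does, the sketch below assumes each split is stored as a JSON Lines file with one {"input": ..., "target": ...} object per line, as in the examples above; that file layout and the helper names load_jsonl and iterate_batches are hypothetical.

```python
import json

import numpy as np


def load_jsonl(path):
    """Read one {"input": [...], "target": [...]} object per line into two int arrays."""
    inputs, targets = [], []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            inputs.append(record["input"])
            targets.append(record["target"])
    # Every sample has 7 input tokens and 4 target digits, so stacking is safe.
    return np.array(inputs, dtype=np.int64), np.array(targets, dtype=np.int64)


def iterate_batches(inputs, targets, batch_size, shuffle=True, seed=0):
    """Yield (input_batch, target_batch) pairs, shuffling with a fixed seed if requested."""
    order = np.arange(len(inputs))
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield inputs[idx], targets[idx]
```

Because every sequence has a fixed length, no padding or attention masking of the encoder input is needed; in a PyTorch pipeline the same arrays would typically be wrapped in a Dataset and DataLoader instead.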
Part B: Attention Implementation
Implement the core attention mechanisms, scaled dot-product attention and multi-head attention, in attention.py.
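A minimal NumPy sketch of both mechanisms follows; each head computes softmax(QK^T / sqrt(d_k)) V. It assumes bias-free projection matrices passed in a params dict, a boolean mask where True means "may attend", and an optional head_mask (one 0/1 value per head) for later ablation. All of these names are illustrative, and the expected attention.py (probably PyTorch, judging by the .pth checkpoint) may use a different interface.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (..., Lq, d_k), k: (..., Lk, d_k), v: (..., Lk, d_v).
    mask: boolean, broadcastable to (..., Lq, Lk); True marks allowed positions."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # blocked positions get ~zero weight
    weights = softmax(scores, axis=-1)
    return weights @ v, weights


def multi_head_attention(x_q, x_kv, params, num_heads, mask=None, head_mask=None):
    """x_q: (B, Lq, d_model), x_kv: (B, Lk, d_model).
    params: dict with W_q, W_k, W_v, W_o, each (d_model, d_model), no biases.
    head_mask: optional (num_heads,) array of 0/1; a 0 zeroes that head's output."""
    B, Lq, d_model = x_q.shape
    Lk = x_kv.shape[1]
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads

    def split(t, length):
        # (B, L, d_model) -> (B, num_heads, L, d_head)
        return t.reshape(B, length, num_heads, d_head).transpose(0, 2, 1, 3)

    q = split(x_q @ params["W_q"], Lq)
    k = split(x_kv @ params["W_k"], Lk)
    v = split(x_kv @ params["W_v"], Lk)
    out, weights = scaled_dot_product_attention(q, k, v, mask)  # out: (B, H, Lq, d_head)
    if head_mask is not None:
        out = out * np.asarray(head_mask).reshape(1, num_heads, 1, 1)
    # Concatenate the heads back to (B, Lq, d_model), then apply the output projection.
    out = out.transpose(0, 2, 1, 3).reshape(B, Lq, d_model)
    return out @ params["W_o"], weights  # weights: (B, H, Lq, Lk), kept for analysis
```

Returning the per-head weights alongside the output is what the attention analysis later relies on, and zeroing a head before the output projection W_o is one common way to ablate it.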
Part C: Model Architecture
Implement a sequence-to-sequence transformer in model.py.
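Two pieces every sequence-to-sequence transformer needs are position information and a causal mask for the decoder. The sketch below assumes the standard sinusoidal encoding (learned position embeddings are an equally valid choice) and teacher forcing with a hypothetical start token; the id 11 in the usage line is an assumption, since the dataset only defines digits 0-9 and ‘+’ = 10.

```python
import numpy as np


def sinusoidal_positional_encoding(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    positions = np.arange(max_len)[:, None]                      # (max_len, 1)
    div = np.power(10000.0, np.arange(0, d_model, 2) / d_model)  # (d_model / 2,)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(positions / div)
    pe[:, 1::2] = np.cos(positions / div)
    return pe


def causal_mask(length):
    """Boolean (length, length) mask; True where query i may attend to key j (j <= i)."""
    return np.tril(np.ones((length, length), dtype=bool))


def shift_right(targets, start_token):
    """Teacher forcing: decoder input is the target shifted right by one position."""
    targets = np.asarray(targets)
    start = np.full((targets.shape[0], 1), start_token, dtype=targets.dtype)
    return np.concatenate([start, targets[:, :-1]], axis=1)


# Decoder input for the first training sample (start token 11 is hypothetical).
print(shift_right([[0, 6, 6, 6]], start_token=11).tolist())  # [[11, 0, 6, 6]]
```

With this layout, the encoder attends freely over the 7 input tokens, while the decoder's self-attention uses causal_mask(4) so that the prediction for each output digit only sees the digits before it.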
Part D: Training
Train your model using the provided training script.
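Since the training script is provided, only the evaluation metric is sketched here. For addition, per-digit accuracy can look high while whole sums are still wrong (a single missed carry breaks the answer), so exact-match and per-position accuracy are worth tracking as well. The function name and dict keys are hypothetical and need not match what the provided script writes to training_log.json.

```python
import numpy as np


def addition_accuracy(logits, targets):
    """logits: (B, T, vocab) decoder scores; targets: (B, T) digit labels.
    Returns per-digit, whole-sum (exact match), and per-position accuracy."""
    preds = np.asarray(logits).argmax(axis=-1)  # (B, T)
    correct = preds == np.asarray(targets)
    return {
        "digit_acc": float(correct.mean()),
        "sequence_acc": float(correct.all(axis=1).mean()),
        # Position 0 is the carry-out digit, position T - 1 the units digit.
        "per_position_acc": correct.mean(axis=0).tolist(),
    }
```

Per-position accuracy is a cheap first signal for the carry analysis: errors concentrated in the higher-place positions point to carries that were not propagated.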
Part E: Attention Analysis
Implement analyze.py to extract and visualize attention patterns.
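One way to make the heatmaps quantitative is to bucket each head's decoder-to-encoder attention by place value: does output digit j attend to input digits of the same place value, to lower places (where a carry would come from), or elsewhere? The sketch assumes cross-attention weights of shape (batch, heads, 4, 7), the most-significant-first layout from Part A, and that decoder position j predicts output digit j. place_value_profile is a hypothetical helper, and a high lower-place share only suggests carry tracking; it does not prove it.

```python
import numpy as np

NUM_DIGITS = 3
PLUS_POS = NUM_DIGITS  # index of the '+' token in the 7-token input


def place_value(input_pos):
    """Power of ten of an input digit (2 = hundreds ... 0 = units); None for '+'."""
    if input_pos == PLUS_POS:
        return None
    offset = input_pos if input_pos < PLUS_POS else input_pos - PLUS_POS - 1
    return NUM_DIGITS - 1 - offset


def place_value_profile(cross_attn):
    """cross_attn: (B, H, T, S) decoder-to-encoder weights with T = 4, S = 7.
    Returns (H, T, 3): attention mass on [same place, lower places, other]."""
    mean_attn = np.asarray(cross_attn).mean(axis=0)  # average over the batch: (H, T, S)
    H, T, S = mean_attn.shape
    profile = np.zeros((H, T, 3))
    for j in range(T):
        out_place = T - 1 - j  # output position 0 is the thousands (carry-out) digit
        for s in range(S):
            p = place_value(s)
            if p == out_place:
                bucket = 0  # same place value as the digit being predicted
            elif p is not None and p < out_place:
                bucket = 1  # a lower place: a possible source of an incoming carry
            else:
                bucket = 2  # a higher place or the '+' token
            profile[:, j, bucket] += mean_attn[:, j, s]
    return profile
```

Plotting profile[h] as a small 4 x 3 heatmap per head complements the raw 4 x 7 attention heatmaps and makes head-to-head comparisons easier.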
Deliverables
Your problem1/ directory must contain:
- All code files as specified above
- results/training_log.json with loss curves and accuracy metrics
- results/best_model.pth with the saved model weights
- results/attention_patterns/ containing:
  - Heatmaps for each attention head
  - Example visualizations on test cases
- results/head_analysis/ containing:
  - Head ablation results
  - Head importance rankings
Your report must include analysis of:
- Attention pattern visualizations from at least 4 different heads
- Head ablation study: which heads are critical vs redundant?
- Discussion: How do attention heads specialize for carry propagation?
- Quantitative results: percentage of heads that can be pruned with minimal accuracy loss
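For the ablation study and the pruning percentage, the sketch below ranks heads by the accuracy drop when each is ablated on its own and reports the fraction whose drop stays within a tolerance. The dict format keyed by (layer, head) and the default tolerance of one accuracy point are assumptions, since the assignment does not define "minimal accuracy loss".

```python
def rank_heads(baseline_acc, ablated_acc):
    """ablated_acc maps (layer, head) -> accuracy with only that head ablated.
    Returns [(head, accuracy_drop), ...] sorted from most to least important."""
    drops = {head: baseline_acc - acc for head, acc in ablated_acc.items()}
    return sorted(drops.items(), key=lambda item: item[1], reverse=True)


def prunable_fraction(baseline_acc, ablated_acc, tolerance=0.01):
    """Fraction of heads whose individual ablation costs at most `tolerance` accuracy."""
    if not ablated_acc:
        return 0.0
    prunable = [h for h, acc in ablated_acc.items() if baseline_acc - acc <= tolerance]
    return len(prunable) / len(ablated_acc)
```

Single-head drops are not additive: two heads that each look redundant may be backing each other up, so the candidate set should be ablated jointly and re-evaluated before a pruning percentage is reported.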