Problem 1: Multi-Head Attention - Multi-Digit Addition

Implement scaled dot-product attention and multi-head attention from scratch. Train on a multi-digit addition task and analyze what different attention heads learn.

Part A: Dataset and Data Loading

You will work with a multi-digit addition dataset where the model must add two 3-digit numbers with carry propagation.
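The model has to learn column-wise addition. Writing that procedure out explicitly can help later, when you choose analysis examples. The sketch below is illustrative and not part of the starter code, and `add_with_carries` and `longest_carry_chain` are hypothetical helper names. It records the carry out of each column, so you can pick test cases with long carry chains (such as 999 + 1) for the attention visualizations.

```python
from typing import List, Tuple

def add_with_carries(a: int, b: int, num_digits: int = 3) -> Tuple[List[int], List[int]]:
    """Schoolbook addition that also records the carry out of each column.

    Hypothetical helper. Returns (sum_digits, carries), both most significant
    first; sum_digits has num_digits + 1 entries, matching the padded target.
    """
    a_digits = [int(d) for d in str(a).zfill(num_digits)]
    b_digits = [int(d) for d in str(b).zfill(num_digits)]
    sum_digits, carries, carry = [], [], 0
    # Walk the columns from least significant (rightmost) to most significant.
    for da, db in zip(reversed(a_digits), reversed(b_digits)):
        total = da + db + carry
        sum_digits.append(total % 10)
        carry = total // 10
        carries.append(carry)
    sum_digits.append(carry)  # the final carry becomes the leading digit
    return sum_digits[::-1], carries[::-1]

def longest_carry_chain(carries: List[int]) -> int:
    """Length of the longest run of consecutive columns that produce a carry."""
    longest = run = 0
    for c in carries:
        run = run + 1 if c else 0
        longest = max(longest, run)
    return longest
```

For 394 + 316 the carries, most significant column first, are [0, 1, 1]. The units column (4 + 6) and the tens column (9 + 1 + 1) both carry, which gives a chain of length 2.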

Generate the dataset:

cd problem1
python generate_data.py --seed 641 --num-digits 3

The dataset contains:

  • Input: two 3-digit numbers separated by a ‘+’ token (e.g., [3, 4, 7, +, 1, 5, 9] for 347 + 159)
  • Output: the sum, zero-padded to 4 digits (e.g., [0, 5, 0, 6] for 506)
  • Training samples: 10,000
  • Validation samples: 2,000
  • Test samples: 2,000

Example training samples:

{"input": [5, 0, 7, 10, 1, 5, 9], "target": [0, 6, 6, 6]}
{"input": [1, 5, 5, 10, 4, 9, 1], "target": [0, 6, 4, 6]}
{"input": [3, 9, 4, 10, 3, 1, 6], "target": [0, 7, 1, 0]}

Token 10 represents the ‘+’ operator, so these samples encode 507 + 159 = 666, 155 + 491 = 646, and 394 + 316 = 710.
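Here is a small sketch of the encoding. It assumes the layout shown above: operand digits most significant first, token 10 between them, and the target left-padded with zeros. The target needs num_digits + 1 positions because the largest possible sum, 999 + 999 = 1998, has four digits. `encode_sample` and `decode_sample` are hypothetical names, and generate_data.py may build samples differently.

```python
PLUS_TOKEN = 10  # '+' operator, as in the dataset description

def encode_sample(a, b, num_digits=3):
    """Hypothetical encoder reproducing the {"input", "target"} format above."""
    a_digits = [int(d) for d in str(a).zfill(num_digits)]
    b_digits = [int(d) for d in str(b).zfill(num_digits)]
    target = [int(d) for d in str(a + b).zfill(num_digits + 1)]
    return {"input": a_digits + [PLUS_TOKEN] + b_digits, "target": target}

def decode_sample(sample):
    """Recover (a, b, sum) from a sample, e.g. to sanity-check a split."""
    def to_int(digits):
        return int("".join(str(d) for d in digits))

    tokens = sample["input"]
    plus = tokens.index(PLUS_TOKEN)  # position of the '+' token
    return to_int(tokens[:plus]), to_int(tokens[plus + 1:]), to_int(sample["target"])
```

A cheap way to catch a corrupted or misaligned file is to check that a + b equals the decoded sum for every line of a split.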

The starter code provides dataset.py, which contains the data loader.
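dataset.py already handles loading, so the following is only a way to inspect the splits independently. It assumes each split is stored as JSON Lines, one object per line as in the examples above. The actual file names and storage format are not specified here, so the hypothetical `load_jsonl` reader takes any iterable of lines. Every input has 7 tokens and every target has 4, so batches need no padding.

```python
import json
import random

def load_jsonl(lines):
    """Parse JSON Lines records into (input_tokens, target_digits) pairs (hypothetical helper)."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        samples.append((record["input"], record["target"]))
    return samples

def iterate_batches(samples, batch_size, shuffle=True, seed=0):
    """Yield (inputs, targets) lists; fixed-length sequences need no padding."""
    order = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(order)  # seeded for reproducible epochs
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [samples[i][0] for i in idx], [samples[i][1] for i in idx]
```

For a file on disk, passing an open file handle to `load_jsonl` works, because iterating over a file yields its lines.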

Part B: Attention Implementation

Implement the core attention mechanisms, scaled dot-product attention and multi-head attention, in attention.py.
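The following is a minimal NumPy sketch that computes Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V for each head. It assumes bias-free projections and a boolean mask where True means "do not attend". The starter attention.py may use PyTorch (the .pth checkpoint suggests so) and different conventions. In that case the same steps map onto torch.matmul, torch.softmax, and Tensor.masked_fill. The functions also return the attention weights, which Part E needs.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # shift by the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    """softmax(q k^T / sqrt(d_k)) v.

    q: (..., Lq, d_k), k: (..., Lk, d_k), v: (..., Lk, d_v).
    mask: boolean, broadcastable to (..., Lq, Lk); True means "do not attend".
    Returns the output (..., Lq, d_v) and the attention weights (..., Lq, Lk).
    """
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)  # large negative, not -inf, to avoid NaN rows
    weights = softmax(scores, axis=-1)
    return weights @ v, weights

def multi_head_attention(x_q, x_kv, w_q, w_k, w_v, w_o, num_heads, mask=None):
    """x_q: (B, Lq, d_model) queries; x_kv: (B, Lk, d_model) keys/values.

    Each w_* is a (d_model, d_model) projection. Use x_kv = x_q for self-attention
    and x_kv = encoder output for cross-attention. Returns the output
    (B, Lq, d_model) and per-head weights (B, num_heads, Lq, Lk) for analysis.
    """
    batch, len_q, d_model = x_q.shape
    len_k = x_kv.shape[1]
    assert d_model % num_heads == 0, "d_model must divide evenly across heads"
    d_head = d_model // num_heads

    def split_heads(t, length):
        # (B, L, d_model) -> (B, H, L, d_head)
        return t.reshape(batch, length, num_heads, d_head).transpose(0, 2, 1, 3)

    q = split_heads(x_q @ w_q, len_q)
    k = split_heads(x_kv @ w_k, len_k)
    v = split_heads(x_kv @ w_v, len_k)
    heads, weights = scaled_dot_product_attention(q, k, v, mask)
    # Concatenate heads back to (B, Lq, d_model), then apply the output projection.
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, len_q, d_model)
    return concat @ w_o, weights
```

Dividing by sqrt(d_k) keeps the dot products from growing with the head dimension. Without it, the softmax saturates and gradients vanish.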

Part C: Model Architecture

Implement a sequence-to-sequence transformer in model.py.
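The architecture is specified only as a "sequence-to-sequence transformer". One common arrangement, stated here as an assumption, has two parts:

  • An encoder runs unmasked self-attention over the 7 input tokens.
  • A decoder runs causally masked self-attention over the 4 output positions, then cross-attention to the encoder output.

Attention is permutation-invariant, so without a positional signal the model could not tell a units digit from a hundreds digit. Below is a NumPy sketch of the positional signal and the causal mask. It uses the sinusoidal encoding from the original Transformer; learned position embeddings are an equally valid choice.

```python
import numpy as np

def sinusoidal_positions(length, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i / d_model)), PE[pos, 2i + 1] = cos(same angle)."""
    assert d_model % 2 == 0, "sin/cos pairs need an even d_model"
    positions = np.arange(length)[:, None]            # (length, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]     # (1, d_model / 2), values 2i
    angles = positions / np.power(10000.0, even_dims / d_model)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def causal_mask(length):
    """Boolean (length, length) mask, True above the diagonal: position t cannot see t+1, t+2, ..."""
    return np.triu(np.ones((length, length), dtype=bool), k=1)
```

Under this arrangement, the encoder adds sinusoidal_positions(7, d_model) to its token embeddings and the decoder adds sinusoidal_positions(4, d_model) to its own. Only decoder self-attention takes causal_mask(4). Cross-attention needs no mask, because every output digit may look at all seven input tokens.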

Part D: Training

Train your model using the provided training script.
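The training script is provided, but two details affect how its numbers should be read.

  • Under teacher forcing, the decoder sees the gold previous digit, so its inputs are the targets shifted right behind a start token. At test time the model must decode greedily from its own predictions, and that is the accuracy that matters.
  • Exact-match accuracy is stricter than per-digit accuracy, because a single wrong carry makes the whole sum wrong.

The sketch below rests on two assumptions. BOS_TOKEN = 11 is a hypothetical start-of-sequence id, since the task defines only digits 0-9 and token 10. `next_token_logits` is a hypothetical stand-in for whatever model call returns logits for the next position.

```python
import numpy as np

# Hypothetical start-of-sequence id: the task only defines digits 0-9 and '+' = 10.
BOS_TOKEN = 11

def shift_right(targets, bos=BOS_TOKEN):
    """Teacher-forcing decoder inputs: [BOS, y0, ..., y_{T-2}] for targets [y0, ..., y_{T-1}]."""
    targets = np.asarray(targets)
    bos_col = np.full((targets.shape[0], 1), bos, dtype=targets.dtype)
    return np.concatenate([bos_col, targets[:, :-1]], axis=1)

def digit_and_exact_accuracy(logits, targets):
    """logits: (B, T, vocab), targets: (B, T). Returns (per-digit accuracy, exact-match accuracy)."""
    preds = np.asarray(logits).argmax(axis=-1)
    correct = preds == np.asarray(targets)
    return float(correct.mean()), float(correct.all(axis=1).mean())

def greedy_decode(next_token_logits, src, out_len=4, bos=BOS_TOKEN):
    """Autoregressive decoding from the model's own predictions.

    next_token_logits(src, prefix) is a hypothetical model hook returning (B, vocab)
    logits for the position after `prefix` (shape (B, t)).
    """
    prefix = np.full((len(src), 1), bos, dtype=np.int64)
    for _ in range(out_len):
        next_digit = np.asarray(next_token_logits(src, prefix)).argmax(axis=-1)
        prefix = np.concatenate([prefix, next_digit[:, None].astype(np.int64)], axis=1)
    return prefix[:, 1:]  # drop the BOS column
```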

Part E: Attention Analysis

Implement analyze.py to extract and visualize attention patterns.
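You can draw the heatmaps themselves with matplotlib's imshow, but a quantitative summary also helps the carry-propagation discussion. The sketch assumes two layouts:

  • input: [a hundreds, a tens, a units, +, b hundreds, b tens, b units]
  • output: [thousands, hundreds, tens, units]

For one cross-attention head, it measures how much weight each output digit puts on the operand digits in its own column. It also measures the weight on the column one place to the right, which is where that digit's carry-in comes from. The function names are hypothetical, and rows are assumed to be the decoder positions that predict each output digit.

```python
import numpy as np

# Assumed layouts (hypothetical helper names):
#   input  indices 0..6 = [a_hundreds, a_tens, a_units, '+', b_hundreds, b_tens, b_units]
#   output indices 0..3 = [thousands, hundreds, tens, units]

def same_column_inputs(k):
    """Operand digits in the same column as output digit k (none for the thousands digit)."""
    return np.array([k - 1, k + 3] if k >= 1 else [], dtype=int)

def carry_source_inputs(k):
    """Operand digits one column to the right, whose carry feeds output digit k (none for units)."""
    return np.array([k, k + 4] if k <= 2 else [], dtype=int)

def alignment_profile(cross_attn):
    """Summarize one head's (4, 7) cross-attention matrix, one dict per output digit."""
    cross_attn = np.asarray(cross_attn)
    profile = []
    for k in range(cross_attn.shape[0]):
        profile.append({
            "output_pos": k,
            "same_column": float(cross_attn[k, same_column_inputs(k)].sum()),
            "carry_source": float(cross_attn[k, carry_source_inputs(k)].sum()),
            "argmax_input": int(cross_attn[k].argmax()),
        })
    return profile
```

A head with high carry_source mass on the thousands, hundreds, and tens outputs is a candidate carry-tracking head. Averaging profiles separately over samples that do and do not carry into a given column helps separate such heads from heads that only copy their own column.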

Deliverables

Your problem1/ directory must contain:

  1. All code files as specified above
  2. results/training_log.json with loss curves and accuracy metrics
  3. results/best_model.pth - saved model weights
  4. results/attention_patterns/ containing:
    • Heatmaps for each attention head
    • Example visualizations on test cases
  5. results/head_analysis/ containing:
    • Head ablation results
    • Head importance rankings
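For items 2 and 3, one possible shape for the log is a set of parallel per-epoch lists, which are easy to plot as loss curves. The key names below are hypothetical, since the expected schema is not specified. The return value of `record_epoch` flags a new best validation exact-match accuracy, which is a natural point to save weights. With PyTorch that is typically torch.save(model.state_dict(), "results/best_model.pth").

```python
import json
from pathlib import Path

# Hypothetical schema: the assignment asks for loss curves and accuracy metrics,
# not these exact key names.
def new_log():
    return {"epoch": [], "train_loss": [], "val_loss": [],
            "val_digit_acc": [], "val_exact_acc": []}

def record_epoch(log, epoch, train_loss, val_loss, val_digit_acc, val_exact_acc):
    """Append one epoch of metrics; return True if val_exact_acc is a new best."""
    is_best = val_exact_acc > max(log["val_exact_acc"], default=float("-inf"))
    for key, value in (("epoch", epoch), ("train_loss", train_loss),
                       ("val_loss", val_loss), ("val_digit_acc", val_digit_acc),
                       ("val_exact_acc", val_exact_acc)):
        log[key].append(value)
    return is_best

def save_log(log, path="results/training_log.json"):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # create results/ if missing
    path.write_text(json.dumps(log, indent=2))
```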

Your report must include analysis of:

  • Attention pattern visualizations from at least 4 different heads
  • Head ablation study: which heads are critical vs redundant?
  • Discussion: How do attention heads specialize for carry propagation?
  • Quantitative results: percentage of heads that can be pruned with minimal accuracy loss
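Suppose you have single-head ablation accuracies, each measured with one head's output zeroed. One way to do this is to zero that head's slice of the concatenated head outputs before the output projection. The importance ranking and the prunable percentage then follow directly. In this sketch the head keys are hypothetical (module, layer, head) tuples. The assignment does not define "minimal accuracy loss", so max_drop is an assumed threshold that should be reported alongside the result.

```python
def rank_heads(baseline_acc, ablated_acc):
    """Sort heads by accuracy drop when ablated alone, most important first.

    ablated_acc maps a head key, e.g. the hypothetical ("decoder-cross", layer, head),
    to test accuracy with only that head zeroed out.
    """
    drops = {head: baseline_acc - acc for head, acc in ablated_acc.items()}
    return sorted(drops.items(), key=lambda item: item[1], reverse=True)

def prunable_percentage(baseline_acc, ablated_acc, max_drop=0.01):
    """Percentage of heads whose individual removal costs at most max_drop accuracy.

    max_drop is an assumed threshold for "minimal accuracy loss"; report it with the result.
    """
    prunable = [head for head, acc in ablated_acc.items() if baseline_acc - acc <= max_drop]
    return 100.0 * len(prunable) / len(ablated_acc)
```

Single-head ablation ignores interactions between heads. Two heads that are each redundant on their own may be jointly necessary, so re-evaluate with the whole prunable set removed at once before reporting the percentage.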