What I learned making Pando

3 min read | By Ashish Kumar
Tags: ML, Interpretability, Python

Building Pando taught me a lot about designing interpretable AI systems. Here are the key takeaways after six months of heads-down work.

The core problem

When you build machine learning systems, interpretability isn't just a nice-to-have - it's the difference between trusting your model and flying blind. The catch: different interpretability methods measure different things, and conflating them leads to false confidence.

Here's a simple example. Say you have a model classifying medical images. You run SHAP and get feature attributions that look sensible. But does that tell you why the model makes mistakes on edge cases? Not necessarily.

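import shap
import torch

# Assumes `model` is a trained PyTorch module, `background` is a small
# batch of reference inputs, and `test_inputs` is the batch to explain.
explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(test_inputs)

# This tells you WHICH features matter most for these inputs...
# ...but NOT whether those features generalize to edge cases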

The four axes of interpretability

After surveying the literature and running our own ablations, I settled on evaluating along four axes:

1. Feature importance rankings

Which inputs does the model attend to? Methods like SHAP, LIME, and integrated gradients all answer this, but they disagree more than you'd expect.

Method                 Speed      Faithfulness  Stability
SHAP (DeepExplainer)   Fast       High          High
LIME                   Slow       Medium        Low
Integrated Gradients   Medium     High          Medium
Attention weights      Very fast  Low           High

Key insight: Faithfulness and stability are often in tension. Pick the method that matches your use case.
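
To put a number on how much two attribution methods disagree, one option is to compare their feature rankings. The sketch below is a minimal, dependency-free illustration. The helper names (`ranks`, `spearman_rank_correlation`, `top_k_overlap`) are hypothetical and not part of Pando or SHAP, and the attribution vectors are made-up placeholders, not real results.

```python
def ranks(values):
    """Rank values from largest to smallest magnitude (0 = most important).
    Ties are broken by position, which is fine for a rough comparison."""
    order = sorted(range(len(values)), key=lambda i: -abs(values[i]))
    result = [0] * len(values)
    for rank, idx in enumerate(order):
        result[idx] = rank
    return result


def spearman_rank_correlation(a, b):
    """Spearman correlation between two attribution vectors (no tie correction)."""
    n = len(a)
    if n != len(b) or n < 2:
        raise ValueError("need two vectors of equal length >= 2")
    ra, rb = ranks(a), ranks(b)
    d2 = sum((x - y) ** 2 for x, y in zip(ra, rb))
    return 1.0 - 6.0 * d2 / (n * (n * n - 1))


def top_k_overlap(a, b, k):
    """Jaccard overlap of the k highest-magnitude features under each method."""
    top_a = {i for i, r in enumerate(ranks(a)) if r < k}
    top_b = {i for i, r in enumerate(ranks(b)) if r < k}
    return len(top_a & top_b) / len(top_a | top_b)


# Placeholder attributions for the same input from two methods (illustrative only)
method_a = [0.50, -0.30, 0.10, 0.05]
method_b = [0.10, 0.45, -0.35, 0.02]

print(spearman_rank_correlation(method_a, method_b))
print(top_k_overlap(method_a, method_b, k=2))
```

A low correlation or small top-k overlap on the same input is a hint that the two methods are answering different questions, which is worth knowing before you trust either one.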

2. Attention visualization

For transformer-based architectures, attention maps are tempting but misleading. Raw attention weights ≠ importance.

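import torch

def get_attention_rollout(model, inputs):
    """
    Attention rollout (Abnar & Zuidema 2020) composes
    attention across all layers for a more faithful map.

    Assumes each layer's self_attn returns (output, attn_weights).
    """
    hooks = []
    attention_maps = []

    def hook_fn(module, args, output):
        # output[1] is None when self_attn is called with need_weights=False,
        # as torch.nn.TransformerEncoderLayer does internally
        if output[1] is None:
            raise RuntimeError("self_attn returned no attention weights")
        attention_maps.append(output[1].detach())

    for layer in model.encoder.layers:
        hooks.append(layer.self_attn.register_forward_hook(hook_fn))

    try:
        with torch.no_grad():
            _ = model(**inputs)
    finally:
        # Remove hooks even if the forward pass fails
        for h in hooks:
            h.remove()

    # Compute rollout
    rollout = None
    for attn in attention_maps:
        if attn.dim() == 4:
            # (batch, heads, seq, seq) -> average over heads
            attn = attn.mean(dim=1)
        # Add the identity for residual connections, then renormalize rows
        eye = torch.eye(attn.size(-1), device=attn.device, dtype=attn.dtype)
        attn = 0.5 * attn + 0.5 * eye
        attn = attn / attn.sum(dim=-1, keepdim=True)
        rollout = attn if rollout is None else torch.bmm(attn, rollout)

    return rollout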

3. Decision path tracing

For tree-based models or models with discrete bottlenecks, you can trace the actual path taken. This is more expensive but gives you ground truth: the trace is what the model actually computed, not an approximation of it.
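
As a rough illustration of what tracing a path means, here is a minimal sketch on a toy binary tree. The `Node` class, `trace_path` function, feature indices, thresholds, and class labels are all hypothetical and exist only for this example. They are not Pando's code or any library's API.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical binary decision-tree node (not a real library type)."""
    feature: Optional[int] = None      # index of the feature tested; None for a leaf
    threshold: float = 0.0             # go left if x[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    prediction: Optional[str] = None   # set only on leaves


def trace_path(root: Node, x: List[float]):
    """Return (prediction, steps), where steps lists every test taken for x."""
    steps = []
    node = root
    while node.prediction is None:
        went_left = x[node.feature] <= node.threshold
        op = "<=" if went_left else ">"
        steps.append(f"x[{node.feature}]={x[node.feature]} {op} {node.threshold}")
        node = node.left if went_left else node.right
    return node.prediction, steps


# Toy tree with made-up thresholds, purely for illustration
tree = Node(
    feature=0, threshold=0.5,
    left=Node(prediction="benign"),
    right=Node(
        feature=1, threshold=2.0,
        left=Node(prediction="benign"),
        right=Node(prediction="malignant"),
    ),
)

prediction, steps = trace_path(tree, [0.9, 3.1])
print(prediction)
for step in steps:
    print(step)
```

For real models you would not hand-roll this. scikit-learn's tree estimators, for instance, expose a `decision_path` method that returns the nodes each sample visits.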

4. Counterfactual analysis

"What's the minimal change that flips the prediction?" This is my favourite axis because it reveals the model's decision boundary in a human-understandable way.

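import torch
import torch.nn.functional as F

def find_counterfactual(model, x, target_class, n_steps=100, lr=0.01, dist_weight=0.1):
    """Gradient-based counterfactual search for a single input x of shape (1, ...)."""
    x = x.detach()
    # Optimize a copy of the input; x itself stays fixed as the reference point
    x_cf = x.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([x_cf], lr=lr)
    target = torch.tensor([target_class], device=x.device)

    for step in range(n_steps):
        optimizer.zero_grad()
        logits = model(x_cf)

        # Stop as soon as the prediction flips, so the change stays small
        if logits.argmax(dim=1).item() == target_class:
            break

        # Maximize target class probability (cross-entropy on the logits)
        loss = F.cross_entropy(logits, target)

        # Minimize squared distance from original (smooth at x_cf == x)
        loss = loss + dist_weight * (x_cf - x).pow(2).sum()

        loss.backward()
        optimizer.step()

    return x_cf.detach()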
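
Once you have a counterfactual such as the one returned by `find_counterfactual`, the useful part is usually the difference from the original input. The sketch below is one way to summarize that difference. `summarize_counterfactual` and its `tol` parameter are hypothetical names for this example, it works on plain Python lists (with tensors you could call `.flatten().tolist()` first), and the numbers are made up.

```python
import math


def summarize_counterfactual(x, x_cf, tol=1e-3):
    """Report which features moved by more than tol, and the overall L2 distance."""
    if len(x) != len(x_cf):
        raise ValueError("x and x_cf must have the same length")
    deltas = [b - a for a, b in zip(x, x_cf)]
    changed = {i: d for i, d in enumerate(deltas) if abs(d) > tol}
    l2 = math.sqrt(sum(d * d for d in deltas))
    return {"changed": changed, "n_changed": len(changed), "l2_distance": l2}


# Made-up flattened input and counterfactual, purely illustrative
x = [1.0, 2.0, 3.0, 4.0]
x_cf = [1.0, 2.5, 3.0, 3.0]

summary = summarize_counterfactual(x, x_cf)
print(summary)
```

A counterfactual that touches few features is generally easier for a person to read than one that nudges every feature slightly, even when the two have a similar L2 distance.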

What we got wrong initially

We spent the first two months optimizing for a single interpretability metric. Big mistake. Models that score well on feature attribution faithfulness often score poorly on human-subject comprehensibility studies. The literature warned us - we didn't listen.

The fix: treat interpretability as a multi-objective problem and be explicit about the trade-offs you're making.
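
One simple way to make those trade-offs explicit is to keep only the methods that are not beaten on every objective at once, i.e. the Pareto front. The sketch below assumes each method has already been scored on each objective with higher meaning better. `pareto_front` is a hypothetical helper, and the scores are placeholders, not measurements from Pando.

```python
def pareto_front(scores):
    """Return names of methods that no other method dominates (higher is better).

    scores maps a method name to a tuple of objective values, e.g.
    (faithfulness, comprehensibility).
    """
    def dominates(a, b):
        # a dominates b if it is at least as good everywhere and better somewhere
        return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))

    front = []
    for name, s in scores.items():
        others = (o for o_name, o in scores.items() if o_name != name)
        if not any(dominates(o, s) for o in others):
            front.append(name)
    return front


# Placeholder scores (faithfulness, comprehensibility); not measured results
scores = {
    "method_a": (0.9, 0.3),
    "method_b": (0.6, 0.7),
    "method_c": (0.5, 0.6),   # worse than method_b on both axes
}

print(pareto_front(scores))
```

Everything left on the front is a defensible choice. Picking among them is where you have to state which objective matters more for your use case.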

Takeaways

  1. No single method is sufficient. Use at least two approaches from different axes.
  2. Faithfulness ≠ usefulness. A perfectly faithful explanation may be incomprehensible.
  3. Test on failure cases. Interpretability methods that look great on in-distribution examples often break down exactly when you need them most.
  4. Log everything. Interpretability research generates a lot of qualitative output - you'll want to go back to earlier runs.
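
For the logging point, an append-only JSON Lines file is about the simplest thing that works. This is a minimal sketch, not how Pando does it: `log_run`, `load_runs`, and the record fields are all hypothetical names chosen for this example.

```python
import json
import time
from pathlib import Path


def log_run(path, method, input_id, attributions, notes=""):
    """Append one interpretability run to a JSON Lines file and return the record."""
    record = {
        "timestamp": time.time(),
        "method": method,
        "input_id": input_id,
        "attributions": list(attributions),
        "notes": notes,
    }
    with Path(path).open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    return record


def load_runs(path):
    """Read every logged run back as a list of dicts."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Example usage (writes to the given path):
# log_run("runs.jsonl", "shap", "img_001", [0.1, -0.2], notes="edge case")
# runs = load_runs("runs.jsonl")
```

Because each line is a self-contained record, you can grep the file, diff it between experiments, or load it into a dataframe later without any schema migration.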

The code is on GitHub if you want to dig in. Happy to answer questions.