!pip install transformers datasets torchviz
3 Superposition in language models
In this section, we will explore how language models encode a variety of concepts in the same sets of neurons.
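Let’s start with a simple language model. We plot a heatmap of each hidden neuron's activation at every position of an input sequence, next to a heatmap of (simulated) concept activations over the same neurons, and then print a correlation matrix relating the two.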
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# Set random seed for reproducibility: the model initialisation, the random
# input ids and the random concept activations below all draw from torch's RNG.
torch.manual_seed(42)
# Define a simple language model
class SimpleLanguageModel(nn.Module):
    """Minimal token model: embedding -> ReLU hidden layer -> vocabulary logits.

    Args:
        vocab_size: number of token ids the model embeds and scores.
        embedding_dim: size of each token embedding.
        hidden_dim: number of hidden neurons (the layer inspected below).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to logits whose last dimension is ``vocab_size``."""
        x = self.embedding(x)
        # The ReLU output of self.linear is the "hidden activation" studied below.
        x = torch.relu(self.linear(x))
        return self.output(x)
# Parameters
vocab_size = 1000   # number of distinct token ids
embedding_dim = 50  # size of each token embedding
hidden_dim = 20     # hidden neurons inspected for superposition
num_concepts = 5    # rows of the simulated concept-activation matrix
# Create model
model = SimpleLanguageModel(vocab_size, embedding_dim, hidden_dim)

# Generate random input: one sequence of 100 random token ids
input_ids = torch.randint(0, vocab_size, (100,))

# Get hidden layer activations (ReLU output of model.linear), shape (100, hidden_dim)
with torch.no_grad():
    embeddings = model.embedding(input_ids)
    hidden_activations = torch.relu(model.linear(embeddings))

# Simulate concept activations, shape (num_concepts, hidden_dim).
# NOTE: these are uniform random numbers, not measured from the model; they
# stand in for how strongly each concept might load on each hidden neuron.
concept_activations = torch.rand((num_concepts, hidden_dim))
# Visualize superposition: both heatmaps share the hidden-neuron axis
plt.figure(figsize=(12, 6))

# Plot hidden neuron activations (rows = neurons, columns = sequence positions)
plt.subplot(1, 2, 1)
plt.imshow(hidden_activations.T, aspect='auto', cmap='viridis')
plt.colorbar(label='Activation')  # give "brighter = higher" an actual scale
plt.title('Hidden Neuron Activations')
plt.xlabel('Sequence Position')
plt.ylabel('Hidden Neuron')

# Plot concept activations (rows = concepts, columns = hidden neurons)
plt.subplot(1, 2, 2)
plt.imshow(concept_activations, aspect='auto', cmap='viridis')
plt.colorbar(label='Activation')
plt.title('Concept Activations')
plt.xlabel('Hidden Neuron')
plt.ylabel('Concept')

plt.tight_layout()
plt.show()
# Print correlation matrix.
# Row 0 is each hidden neuron's mean activation over the sequence; rows
# 1..num_concepts are the simulated concept vectors. torch.corrcoef treats each
# row as a variable observed across the hidden_dim neurons, so the result is a
# (num_concepts + 1) x (num_concepts + 1) matrix.
correlation_matrix = torch.corrcoef(
    torch.cat([hidden_activations.mean(dim=0).unsqueeze(0), concept_activations])
)
print("Correlation Matrix:")
print(correlation_matrix)
Correlation Matrix:
tensor([[ 1.0000, 0.2627, -0.0627, 0.0544, 0.1577, -0.2853],
[ 0.2627, 1.0000, 0.0604, -0.0051, 0.5944, 0.3043],
[-0.0627, 0.0604, 1.0000, 0.0394, 0.1046, -0.1701],
[ 0.0544, -0.0051, 0.0394, 1.0000, -0.2142, 0.0359],
[ 0.1577, 0.5944, 0.1046, -0.2142, 1.0000, -0.1420],
[-0.2853, 0.3043, -0.1701, 0.0359, -0.1420, 1.0000]])
3.0.1 Interpretation
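These visualizations illustrate the idea of multiple concepts being encoded in the same set of neurons. (In this toy example the concept activations are random numbers rather than values measured from the model, so the plots illustrate the idea rather than prove it.)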
In the first (left) subplot, we see a heatmap of activations for each hidden neuron across the positions of the input sequence.
- Each row represents a single neuron in the hidden layer
- Each column represents a position in the input sequence
- The color intensity indicates the activation strength of each neuron at each position, with brighter colors showing higher activation
In the second (right) subplot, the concept-activation heatmap shows activation patterns for different concepts across the hidden neurons.
- Rows represent different concepts
- Columns represent neurons in the hidden layer
- Color intensity indicates how strongly each concept is associated with each hidden neuron
- Brighter colors indicate stronger activations
3.0.2 The key idea
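These two visualizations overlap in hidden-neuron space: the neurons that respond to positions in the sequence are the same neurons each concept is associated with. If a real model spread its concepts across shared neurons like this, instead of giving each concept a dedicated neuron, the sequence and concept information would coexist in the same hidden layer — that is superposition. Because the concept activations here are randomly generated, this figure shows what superposition would look like rather than demonstrating it in this model.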
import torch
import torch.nn as nn
import numpy as np
from transformers import DistilBertTokenizer, DistilBertModel
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
# Check if CUDA is available and pick the device used for all model inputs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load pre-trained model and tokenizer.
# Note: this rebinds `model`, replacing the SimpleLanguageModel from the earlier cell.
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertModel.from_pretrained(model_name).to(device)
# Load a subset of the GLUE dataset (SST-2 for sentiment analysis):
# the first 1000 rows of the training split.
dataset = load_dataset("glue", "sst2", split="train[:1000]")
# Tokenize the dataset
def tokenize_function(examples):
    """Tokenize a batch of sentences, padded/truncated to exactly 128 tokens."""
    return tokenizer(examples["sentence"], padding="max_length", truncation=True, max_length=128, return_tensors="pt")


# Remove every original column (including "sentence" and "label") so only the
# tokenizer outputs remain; `dataset` itself is untouched and still holds labels.
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
tokenized_dataset.set_format("torch")
# Define linguistic concepts (simple version): each concept is a small word list
concepts = {
    "positive": ["good", "great", "excellent", "wonderful", "fantastic"],
    "negative": ["bad", "terrible", "awful", "horrible", "poor"],
    "neutral": ["okay", "fine", "average", "mediocre", "so-so"],
}
# Function to get hidden states
def get_hidden_states(batch):
    """Run the global DistilBERT ``model`` on one tokenized batch.

    Args:
        batch: mapping of column name to tensor; only ``input_ids`` and
            ``attention_mask`` are moved to ``device`` and passed to the model.

    Returns:
        The model's ``last_hidden_state`` as a CPU numpy array — presumably
        (batch_size, sequence_length, hidden_size); TODO confirm.
    """
    inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.cpu().numpy()
# Get hidden states for the dataset, 32 examples at a time, keeping each
# batch's attention mask so padding can be excluded from the average below.
hidden_states = []
attention_masks = []
for i in range(0, len(tokenized_dataset), 32):
    batch = tokenized_dataset[i:i + 32]
    hidden_states.append(get_hidden_states(batch))
    attention_masks.append(batch["attention_mask"].numpy())
hidden_states = np.concatenate(hidden_states, axis=0)
attention_masks = np.concatenate(attention_masks, axis=0)

# Calculate average hidden state for each input over its real tokens only.
# Every sentence is padded to max_length=128, so a plain mean over axis=1 would
# also average in the padding positions, which carry no sentence content.
token_weights = attention_masks[:, :, None].astype(hidden_states.dtype)
avg_hidden_states = (hidden_states * token_weights).sum(axis=1) / np.maximum(
    token_weights.sum(axis=1), 1.0  # guard against an all-zero mask
)
# Perform PCA: project each averaged hidden state onto 2 principal components
pca = PCA(n_components=2)
reduced_states = pca.fit_transform(avg_hidden_states)

# Visualize the reduced hidden states: one point per sentence, in dataset order,
# coloured by that sentence's label.
plt.figure(figsize=(12, 8))
scatter = plt.scatter(reduced_states[:, 0], reduced_states[:, 1], c=dataset["label"], cmap="coolwarm", alpha=0.6)
plt.colorbar(scatter)
plt.title("PCA of Average Hidden States Colored by Sentiment")
plt.xlabel("First Principal Component")
plt.ylabel("Second Principal Component")
plt.show()
# Function to get concept embeddings
def get_concept_embeddings(concepts):
    """Embed each concept as the mean DistilBERT representation of its words.

    Each word is tokenized on its own and its ``last_hidden_state`` is averaged
    over all token positions (presumably including the [CLS]/[SEP] special
    tokens the tokenizer adds — TODO confirm). The per-word vectors are then
    averaged into one flat vector per concept.

    Args:
        concepts: mapping of concept name to a list of words.

    Returns:
        dict mapping each concept name to a 1-D numpy vector.
    """
    concept_embeddings = {}
    for concept, words in concepts.items():
        embeddings = []
        for word in words:
            inputs = tokenizer(word, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            embeddings.append(outputs.last_hidden_state.mean(dim=1).cpu().numpy())
        concept_embeddings[concept] = np.mean(embeddings, axis=0).flatten()
    return concept_embeddings
concept_embeddings = get_concept_embeddings(concepts)

# Calculate correlation between average hidden states and concept embeddings:
# for each concept, the Pearson correlation with every sentence's averaged
# hidden state, then averaged over sentences.
correlations = {}
for concept, embedding in concept_embeddings.items():
    corr = np.mean([np.corrcoef(avg_state, embedding)[0, 1] for avg_state in avg_hidden_states])
    # Plain float: under NumPy >= 2 the dict repr would otherwise print np.float64(...)
    correlations[concept] = float(corr)
# Visualize correlations: one bar per concept
plt.figure(figsize=(10, 6))
sns.barplot(x=list(correlations.keys()), y=list(correlations.values()))
plt.title("Average Correlation between Hidden States and Concept Embeddings")
plt.ylabel("Correlation")
plt.show()

print("Correlations:", correlations)
Using device: cuda
Correlations: {'positive': 0.33541225356469684, 'negative': 0.3389937145819827, 'neutral': 0.37208967616205857}
3.0.3 Interpretation
In the PCA plot, each point is one sentence coloured by its sentiment label. If points of different colours form separate clusters, that suggests the model distinguishes between sentiments in its hidden representations.
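In the correlation bar plot, a higher bar means the hidden states are, on average, more strongly correlated with that concept's embedding. Moderately high correlations for multiple concepts suggest the model may encode several concepts simultaneously in its hidden states. However, when positive, negative and neutral words all correlate about equally, much of that correlation is shared rather than concept-specific, so it should not be read as strong evidence on its own.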