!pip install jupyterlab-nvidia-nsight
4 Profiling and optimizing PyTorch training
(Make sure you are using the free T4 runtime in Colab)
Since GPU usage is the most expensive part of ML training and inference, no small amount of work goes into optimizing it. In the real world, very few organizations and developers work on low-level kernel optimizations. They typically work further up the stack with frameworks such as PyTorch, leaving PyTorch’s optimizations to those working on its backend (which, of course, uses CUDA).
To give us a lens into the operations being performed on the accelerator and their efficiency, a variety of profiling tools are available. In this notebook, we will explore NVIDIA’s Nsight Systems, which is available as a desktop application and as a command-line tool (nsys).
4.0.1 Install Nsight tools
Since a Colab runtime is essentially a Linux-based virtual machine, we can use apt to install the NVIDIA tools.
%%bash
apt update
apt install -y --no-install-recommends gnupg
echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu$(source /etc/lsb-release; echo "$DISTRIB_RELEASE" | tr -d .)/$(dpkg --print-architecture) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list
apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
apt update
apt install -y nsight-systems-cli
4.0.2 Check the installation
!nsys status -e
Timestamp counter supported: Yes
CPU Profiling Environment Check
Root privilege: enabled
Linux Kernel Paranoid Level = 2
Linux Distribution = Ubuntu
Linux Kernel Version = 6.1.85+: OK
Linux perf_event_open syscall available: OK
Sampling trigger event available: OK
Intel(c) Last Branch Record support: Not Available
CPU Profiling Environment (process-tree): OK
CPU Profiling Environment (system-wide): OK
See the product documentation at https://docs.nvidia.com/nsight-systems for more information,
including information on how to set the Linux Kernel Paranoid Level.
4.0.3 Simple attention
Here’s our basic attention mechanism, which computes query, key, and value matrices to generate weighted representations of the input data. The SimpleTransformer class combines this attention mechanism with a residual connection and layer normalization.
We will include profiling code to measure CPU and GPU performance metrics when running the model on sample input data.
%%writefile profiler.py
import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = torch.softmax(attn_weights, dim=-1)
        return torch.matmul(attn_weights, v)

class SimpleTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = SimpleAttention(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.attention(x)
        return self.norm(x + attn_output)

# Create a model and sample input
embed_dim = 256
seq_length = 100
batch_size = 32

model = SimpleTransformer(embed_dim, num_heads=1).cuda()
sample_input = torch.randn(batch_size, seq_length, embed_dim).cuda()

# Note: the profiling API lives in torch.profiler, not torch.cuda.profiler
import torch.profiler as profiler

# Warm-up run
model(sample_input)

# Profile the model
with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True) as prof:
    with profiler.record_function("model_inference"):
        model(sample_input)

# Print profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Writing profiler.py
!nsys profile --stats=true python profiler.py
Collecting data...
Generating '/tmp/nsys-report-4b28.qdstrm'
[1/8] [========================100%] report1.nsys-rep
[2/8] [========================100%] report1.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /content/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report
Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ------------ ------------ --------- ----------- ------------ ----------------------
76.0 1,702,896,538 29 58,720,570.3 77,869,201.0 4,006 100,150,826 44,495,375.6 poll
13.2 295,663,488 1,672 176,832.2 3,329.0 1,004 15,374,049 641,970.2 read
4.0 89,772,667 611 146,927.4 16,308.0 1,509 23,241,632 1,163,272.8 ioctl
3.9 86,647,314 5,802 14,934.0 3,006.5 1,006 13,686,743 246,663.1 stat64
1.7 38,487,406 992 38,797.8 13,551.0 2,049 844,849 101,781.4 open64
0.5 10,831,409 7,221 1,500.0 1,469.0 1,000 18,716 610.5 lstat64
0.3 6,090,781 1 6,090,781.0 6,090,781.0 6,090,781 6,090,781 0.0 nanosleep
0.2 3,767,157 76 49,567.9 10,444.5 3,410 2,163,110 247,130.7 mmap64
0.2 3,716,694 1,853 2,005.8 1,794.0 1,018 25,280 1,277.1 fstat64
0.0 596,671 4 149,167.8 49,473.0 35,192 462,533 209,267.0 sem_timedwait
0.0 521,407 74 7,046.0 3,520.5 1,697 169,539 19,585.7 fopen
0.0 333,517 20 16,675.8 9,374.5 2,751 105,415 22,298.1 mmap
0.0 303,794 16 18,987.1 12,165.5 1,402 57,334 18,133.4 write
0.0 261,847 8 32,730.9 32,683.5 23,171 39,984 5,418.0 fgets
0.0 207,125 3 69,041.7 68,447.0 66,655 72,023 2,733.0 sleep
0.0 142,461 70 2,035.2 1,419.5 1,079 13,265 1,673.8 fclose
0.0 118,808 2 59,404.0 59,404.0 56,211 62,597 4,515.6 pthread_create
0.0 114,648 9 12,738.7 13,780.0 7,881 21,544 4,325.4 munmap
0.0 92,736 12 7,728.0 4,024.0 1,623 27,163 8,106.5 pthread_cond_signal
0.0 85,302 15 5,686.8 4,267.0 2,028 22,918 5,157.5 open
0.0 33,190 5 6,638.0 5,099.0 1,635 16,932 5,952.6 fopen64
0.0 23,712 2 11,856.0 11,856.0 8,025 15,687 5,417.9 socket
0.0 10,493 1 10,493.0 10,493.0 10,493 10,493 0.0 connect
0.0 8,915 1 8,915.0 8,915.0 8,915 8,915 0.0 pthread_cond_broadcast
0.0 7,372 1 7,372.0 7,372.0 7,372 7,372 0.0 pipe2
0.0 6,383 1 6,383.0 6,383.0 6,383 6,383 0.0 getc
0.0 5,305 1 5,305.0 5,305.0 5,305 5,305 0.0 fread
0.0 5,009 2 2,504.5 2,504.5 1,829 3,180 955.3 sigaction
0.0 4,778 4 1,194.5 1,195.0 1,043 1,345 123.3 fflush
0.0 3,201 3 1,067.0 1,055.0 1,033 1,113 41.3 fcntl
0.0 2,580 1 2,580.0 2,580.0 2,580 2,580 0.0 fputs_unlocked
0.0 2,300 1 2,300.0 2,300.0 2,300 2,300 0.0 bind
0.0 1,609 1 1,609.0 1,609.0 1,609 1,609 0.0 fcntl64
0.0 1,505 1 1,505.0 1,505.0 1,505 1,505 0.0 listen
0.0 1,318 1 1,318.0 1,318.0 1,318 1,318 0.0 pthread_mutex_trylock
[5/8] Executing 'cuda_api_sum' stats report
Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ------------ ------------ --------- ---------- ------------ ----------------------------
59.0 133,942,988 8 16,742,873.5 44,480.0 18,304 75,424,540 29,365,709.9 cudaLaunchKernel
19.9 45,123,527 9 5,013,725.2 21,450.0 7,039 44,149,269 14,677,890.0 cudaMemcpyAsync
19.1 43,452,758 2 21,726,379.0 21,726,379.0 5,346,866 38,105,892 23,164,129.4 cudaFree
0.8 1,898,877 6 316,479.5 216,731.0 12,598 976,777 341,108.4 cudaMalloc
0.6 1,421,966 18 78,998.1 668.0 616 1,401,568 330,071.6 cudaEventCreateWithFlags
0.3 658,905 3 219,635.0 3,212.0 2,603 653,090 375,383.2 cudaStreamIsCapturing_v10000
0.2 469,330 1,149 408.5 256.0 126 146,038 4,301.4 cuGetProcAddress_v2
0.1 182,122 9 20,235.8 7,191.0 5,845 68,490 20,841.1 cudaStreamSynchronize
0.0 7,048 3 2,349.3 2,324.0 2,234 2,490 129.9 cuInit
0.0 1,790 4 447.5 300.5 252 937 327.1 cuModuleGetLoadingMode
[6/8] Executing 'cuda_gpu_kern_sum' stats report
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- --------- --------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
61.8 632,402 3 210,800.7 210,491.0 209,947 211,964 1,043.5 volta_sgemm_128x64_tn
14.3 146,205 1 146,205.0 146,205.0 146,205 146,205 0.0 volta_sgemm_64x64_tn
11.2 114,973 1 114,973.0 114,973.0 114,973 114,973 0.0 volta_sgemm_128x64_nn
8.1 82,590 1 82,590.0 82,590.0 82,590 82,590 0.0 void at::native::<unnamed>::vectorized_layer_norm_kernel<float, float>(int, T2, const T1 *, const T…
3.1 32,127 1 32,127.0 32,127.0 32,127 32,127 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctor_add<float>, at::deta…
1.4 14,240 1 14,240.0 14,240.0 14,240 14,240 0.0 void <unnamed>::softmax_warp_forward<float, float, float, (int)7, (bool)0, (bool)0>(T2 *, const T1 …
[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
Time (%) Total Time (ns) Count Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Operation
-------- --------------- ----- -------- -------- -------- -------- ----------- ----------------------------
100.0 675,793 9 75,088.1 768.0 735 599,634 197,031.0 [CUDA memcpy Host-to-Device]
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
Total (MB) Count Avg (MB) Med (MB) Min (MB) Max (MB) StdDev (MB) Operation
---------- ----- -------- -------- -------- -------- ----------- ----------------------------
4.068 9 0.452 0.001 0.001 3.277 1.067 [CUDA memcpy Host-to-Device]
Generated:
/content/report1.nsys-rep
/content/report1.sqlite
(Numbers will differ slightly each time we run these cells)
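Notice that the nvtx_sum report was skipped because our script emits no NVTX ranges. If we want nsys to attribute time to named regions of our own code, we can wrap them in NVTX markers via torch.cuda.nvtx; here is a minimal sketch (the range name is our own, arbitrary choice):

import torch
import torch.cuda.nvtx as nvtx

# Mark a named region so it shows up in nsys's nvtx_sum report
nvtx.range_push("model_inference")  # hypothetical range name
model(sample_input)
torch.cuda.synchronize()            # ensure the GPU work falls inside the range
nvtx.range_pop()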
Let’s analyze the “cuda_gpu_kern_sum” report, which shows the GPU kernel executions:

1. volta_sgemm_128x64_tn (61.9% of GPU time): This is likely the matrix multiplication for computing attention weights (q * k.transpose(-2, -1)). It uses NVIDIA’s optimized GEMM (General Matrix Multiplication) kernel and is typically the most compute-intensive operation in a transformer model.
2. volta_sgemm_64x64_tn (14.2% of GPU time): This could be another part of the attention computation, possibly the final matrix multiplication with the value matrix (attn_weights * v).
3. volta_sgemm_128x64_nn (11.3% of GPU time): This might be the matrix multiplication in one of the linear layers (query, key, or value projection).
4. vectorized_layer_norm_kernel (8.1% of GPU time): This corresponds to the LayerNorm operation in the SimpleTransformer class.
5. vectorized_elementwise_kernel (3.1% of GPU time): This could be the element-wise addition in the residual connection (x + attn_output).
6. softmax_warp_forward (1.4% of GPU time): This is the softmax operation applied to the attention weights.

The SimpleAttention class operations are primarily represented by items 1, 2, 3, and 6 in this list. Together they account for about 88.8% of the GPU kernel execution time, which indicates that the attention mechanism is indeed a significant part of the computation. To optimize this, we could:
- Use the optimized attention mechanism as suggested in the tutorial (torch.nn.functional.scaled_dot_product_attention).
- Experiment with different batch sizes or sequence lengths to find the optimal configuration for your hardware.
- Consider using mixed precision (float16); a minimal sketch follows below.
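As a sketch of that last suggestion (using the model and sample_input defined above), we can run inference under torch.autocast so that matmul-heavy operations execute in float16, which can use the T4’s tensor cores:

import torch

# Mixed-precision inference: matmuls run in float16 under autocast,
# while numerically sensitive ops such as LayerNorm are kept in float32.
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(sample_input)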
4.1 Flash Attention
Transformer models can be bottlenecked by self-attention, a mechanism with quadratic time and memory complexity in the sequence length.
Standard attention uses High Bandwidth Memory (HBM) to store, read, and write the keys, queries, and values.
HBM has large capacity but is slow relative to the GPU’s on-chip SRAM.
Standard attention typically follows this pattern:
- Load keys, queries, and values from HBM into GPU on-chip SRAM
- Compute a single step of attention and write the result back to HBM
- Repeat for each attention step
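Here is a sketch of that pattern in PyTorch terms, with the intermediate that must round-trip through HBM called out. (The sizes assume the batch_size = 32, seq_length = 1000 shapes used in the next cell; this illustrates the memory traffic, not the actual kernel code.)

import torch

def standard_attention(q, k, v):
    # scores is a (batch, seq, seq) intermediate: at batch 32 and seq 1000 in
    # float32 it is 32 * 1000 * 1000 * 4 bytes = 128 MB, materialized in HBM.
    scores = torch.matmul(q, k.transpose(-2, -1))  # write scores to HBM
    attn = torch.softmax(scores, dim=-1)           # read scores, write attn to HBM
    return torch.matmul(attn, v)                   # read attn from HBM again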
We can optimize the model by using Flash Attention.
Flash Attention works out of SRAM, which is much smaller than HBM but much faster. This method:
- Loads keys, queries and values once
- Fuses the operations of the attention mechanism
- Writes back to HBM
Let’s use the torch.nn.functional.scaled_dot_product_attention function, which is optimized for GPUs and uses the Flash Attention algorithm when available. For more on this mechanism, see the FlashAttention research paper.
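PyTorch picks among several backends for scaled_dot_product_attention (Flash, memory-efficient, and a plain math fallback); on a T4 the Flash kernels are typically unavailable, so one of the fallbacks is used. As a hedged sketch using the PyTorch 2.0–2.2 context manager torch.backends.cuda.sdp_kernel (newer releases expose torch.nn.attention.sdpa_kernel instead), we can restrict which backends are allowed:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(32, 1000, 256, device="cuda")

# Disallow the math fallback; if neither remaining backend supports
# these inputs on this GPU, the call raises instead of silently falling back.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)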
%%writefile profiler.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class OptimizedAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** -0.5

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        return F.scaled_dot_product_attention(q, k, v, scale=self.scale)

# Update the SimpleTransformer class to use OptimizedAttention
class OptimizedTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = OptimizedAttention(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.attention(x)
        return self.norm(x + attn_output)

# Create a model and sample input
embed_dim = 256
seq_length = 1000
batch_size = 32

# Create a new model with the optimized attention
optimized_model = OptimizedTransformer(embed_dim, num_heads=1).cuda()
sample_input = torch.randn(batch_size, seq_length, embed_dim).cuda()

# Note: the profiling API lives in torch.profiler, not torch.cuda.profiler
import torch.profiler as profiler

# Warm-up run
optimized_model(sample_input)

# Profile the optimized model
with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True) as prof:
    with profiler.record_function("optimized_model_inference"):
        optimized_model(sample_input)

# Print profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Overwriting profiler.py
!nsys profile --stats=true python profiler.py
Collecting data...
Generating '/tmp/nsys-report-ed31.qdstrm'
[1/8] [========================100%] report2.nsys-rep
[2/8] [========================100%] report2.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /content/report2.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report
Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ------------ ----------- --------- ----------- ------------ ---------------------
79.1 501,166,791 17 29,480,399.5 2,985,176.0 3,734 100,157,864 38,521,524.2 poll
11.6 73,748,312 639 115,412.1 12,665.0 1,374 18,555,420 1,004,303.4 ioctl
2.9 18,137,686 5,796 3,129.3 2,490.0 1,000 47,046 2,214.1 stat64
1.6 9,881,832 6,366 1,552.3 1,470.0 1,000 21,190 685.9 lstat64
1.3 8,400,908 1,104 7,609.5 2,719.5 1,007 251,640 16,190.3 read
1.2 7,379,916 992 7,439.4 6,844.5 1,883 28,734 2,950.8 open64
0.8 5,069,938 1 5,069,938.0 5,069,938.0 5,069,938 5,069,938 0.0 nanosleep
0.6 3,616,562 76 47,586.3 10,088.5 4,364 2,222,049 254,372.7 mmap64
0.5 3,406,279 1,852 1,839.2 1,636.5 1,070 49,459 1,738.7 fstat64
0.1 906,478 9 100,719.8 67,874.0 41,581 396,079 111,299.9 sem_timedwait
0.1 354,877 74 4,795.6 3,567.5 1,586 25,689 4,162.3 fopen
0.0 287,427 27 10,645.4 7,683.0 2,462 61,914 11,272.9 mmap
0.0 252,649 8 31,581.1 32,863.0 12,821 39,628 8,314.4 fgets
0.0 221,124 3 73,708.0 71,618.0 68,171 81,335 6,826.3 sleep
0.0 157,808 16 9,863.0 8,608.0 4,446 27,809 5,358.3 munmap
0.0 141,700 69 2,053.6 1,431.0 1,007 14,453 1,763.4 fclose
0.0 123,084 16 7,692.8 5,635.5 1,063 22,553 6,244.1 write
0.0 113,446 2 56,723.0 56,723.0 52,791 60,655 5,560.7 pthread_create
0.0 80,095 15 5,339.7 4,115.0 1,789 15,904 3,756.7 open
0.0 29,997 5 5,999.4 3,918.0 1,410 16,372 5,916.5 fopen64
0.0 25,575 2 12,787.5 12,787.5 9,492 16,083 4,660.5 socket
0.0 13,661 5 2,732.2 1,858.0 1,045 6,703 2,274.3 pthread_cond_signal
0.0 12,962 1 12,962.0 12,962.0 12,962 12,962 0.0 connect
0.0 12,710 2 6,355.0 6,355.0 1,101 11,609 7,430.3 pthread_mutex_trylock
0.0 10,017 1 10,017.0 10,017.0 10,017 10,017 0.0 pipe2
0.0 6,291 1 6,291.0 6,291.0 6,291 6,291 0.0 getc
0.0 4,856 1 4,856.0 4,856.0 4,856 4,856 0.0 fread
0.0 3,791 3 1,263.7 1,188.0 1,062 1,541 248.3 fcntl
0.0 3,692 3 1,230.7 1,150.0 1,078 1,464 205.3 fflush
0.0 3,620 2 1,810.0 1,810.0 1,524 2,096 404.5 sigaction
0.0 2,106 1 2,106.0 2,106.0 2,106 2,106 0.0 fputs_unlocked
0.0 1,780 1 1,780.0 1,780.0 1,780 1,780 0.0 bind
0.0 1,637 1 1,637.0 1,637.0 1,637 1,637 0.0 fcntl64
0.0 1,376 1 1,376.0 1,376.0 1,376 1,376 0.0 listen
[5/8] Executing 'cuda_api_sum' stats report
Time (%) Total Time (ns) Num Calls Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ----------- ----------- --------- ---------- ------------ ----------------------------
64.9 73,600,176 10 7,360,017.6 25,209.5 13,686 29,189,994 12,030,081.2 cudaLaunchKernel
19.6 22,201,745 9 2,466,860.6 11,600.0 3,985 15,341,866 5,310,374.7 cudaMemcpyAsync
13.2 15,018,628 2 7,509,314.0 7,509,314.0 2,117,769 12,900,859 7,624,796.1 cudaFree
2.0 2,217,914 13 170,608.8 172,874.0 4,544 329,658 81,216.7 cudaMalloc
0.2 222,195 1,149 193.4 165.0 89 1,779 108.2 cuGetProcAddress_v2
0.2 173,034 9 19,226.0 6,419.0 5,973 55,838 17,545.3 cudaStreamSynchronize
0.0 25,834 18 1,435.2 371.0 354 15,734 3,605.2 cudaEventCreateWithFlags
0.0 24,330 10 2,433.0 2,015.5 1,109 6,827 1,643.2 cudaStreamIsCapturing_v10000
0.0 4,335 3 1,445.0 1,419.0 1,209 1,707 250.0 cuInit
0.0 2,221 4 555.3 217.5 150 1,636 721.3 cuModuleGetLoadingMode
[6/8] Executing 'cuda_gpu_kern_sum' stats report
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- ----------- ----------- --------- --------- ----------- ----------------------------------------------------------------------------------------------------
57.5 11,818,611 4 2,954,652.8 1,712,425.5 1,701,017 6,692,743 2,492,068.0 volta_sgemm_128x64_tn
28.4 5,834,459 1 5,834,459.0 5,834,459.0 5,834,459 5,834,459 0.0 volta_sgemm_128x64_nn
5.8 1,197,893 1 1,197,893.0 1,197,893.0 1,197,893 1,197,893 0.0 void <unnamed>::softmax_warp_forward<float, float, float, (int)10, (bool)0, (bool)0>(T2 *, const T1…
3.7 763,694 1 763,694.0 763,694.0 763,694 763,694 0.0 void at::native::<unnamed>::vectorized_layer_norm_kernel<float, float>(int, T2, const T1 *, const T…
2.6 535,827 2 267,913.5 267,913.5 263,226 272,601 6,629.1 void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
1.9 389,943 1 389,943.0 389,943.0 389,943 389,943 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctor_add<float>, at::deta…
[7/8] Executing 'cuda_gpu_mem_time_sum' stats report
Time (%) Total Time (ns) Count Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Operation
-------- --------------- ----- --------- -------- -------- --------- ----------- ----------------------------
100.0 6,588,201 9 732,022.3 736.0 704 6,510,028 2,166,783.8 [CUDA memcpy Host-to-Device]
[8/8] Executing 'cuda_gpu_mem_size_sum' stats report
Total (MB) Count Avg (MB) Med (MB) Min (MB) Max (MB) StdDev (MB) Operation
---------- ----- -------- -------- -------- -------- ----------- ----------------------------
33.560 9 3.729 0.001 0.001 32.768 10.890 [CUDA memcpy Host-to-Device]
Generated:
/content/report2.nsys-rep
/content/report2.sqlite
(Numbers will differ slightly each time we run these cells)
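As a sanity check on the memory reports, the largest Host-to-Device copy in each run matches the sample input tensor we created (float32 is 4 bytes per element; recall the second script also raised seq_length from 100 to 1,000):

# First run:  32 * 100 * 256 * 4 bytes  =  3,276,800 bytes ≈  3.277 MB
# Second run: 32 * 1000 * 256 * 4 bytes = 32,768,000 bytes = 32.768 MB
for seq_length in (100, 1000):
    print(32 * seq_length * 256 * 4 / 1e6, "MB")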
Looking at the “cuda_gpu_kern_sum” report, we notice:
- volta_sgemm_128x64_tn (57.5% of GPU time, previously 61.9%):
  - A decrease in the share of time spent in what is likely the matrix multiplication for computing attention weights. Though small on this tiny sample data, imagine these gains multiplied across real-world training and inference involving text, images, video, etc.
- volta_sgemm_128x64_nn (28.4%, previously 11.3%):
  - The linear-layer matrix multiplications (query, key, and value projections), which now make up a larger share of the total.
- softmax_warp_forward (5.8%, previously 1.4%):
  - This is still the softmax operation applied to the attention weights.
- vectorized_layer_norm_kernel (3.7%, previously 8.1%):
  - This corresponds to the LayerNorm operation in the transformer class.
- vectorized_elementwise_kernel (2.6% + 1.9% = 4.5%, previously 3.1%):
  - This now appears as two separate kernels, possibly the attention scaling and the residual addition.
- volta_sgemm_64x64_tn (previously 14.2%):
  - No longer appears as a separate kernel; note that volta_sgemm_128x64_tn now shows 4 instances instead of 3, so the final multiplication with the value matrix is likely handled by the larger GEMM variant at this sequence length.
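Kernel-time percentages from single profiled runs are noisy, so as a final check we could time the two models head to head with torch.utils.benchmark. A minimal sketch, assuming both SimpleTransformer and OptimizedTransformer are defined in the current session (e.g., pasted into one cell):

import torch
import torch.utils.benchmark as benchmark

x = torch.randn(32, 1000, 256, device="cuda")
for cls in (SimpleTransformer, OptimizedTransformer):
    model = cls(256, num_heads=1).cuda().eval()
    timer = benchmark.Timer(
        stmt="model(x)",
        globals={"model": model, "x": x},
    )
    # Timer synchronizes CUDA appropriately when timing GPU code
    print(cls.__name__, timer.timeit(100))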