4  Profiling and optimizing PyTorch training


!pip install jupyterlab-nvidia-nsight

(Make sure you are using the free T4 runtime in Colab)

Since GPU time is usually the most expensive part of ML training and inference, a great deal of work goes into using GPUs efficiently. In practice, very few organizations and developers write low-level kernel optimizations themselves. Most work further up the stack with frameworks such as PyTorch and leave kernel-level optimization to the developers of PyTorch’s backend (which, on NVIDIA GPUs, is built on CUDA).

To see which operations run on the accelerator, and how efficiently, we can use one of several profiling tools. In this notebook, we explore NVIDIA Nsight Systems, which is available both as a desktop application and as a command-line tool (nsys).

4.0.1 Install Nsight tools

Since a Colab runtime is essentially a Linux virtual machine, we can use apt to install the NVIDIA tools:

%%bash

# Add NVIDIA's devtools apt repository for this Ubuntu release and CPU architecture
apt update
apt install -y --no-install-recommends gnupg
echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu$(source /etc/lsb-release; echo "$DISTRIB_RELEASE" | tr -d .)/$(dpkg --print-architecture) /" | tee /etc/apt/sources.list.d/nvidia-devtools.list
# Import NVIDIA's repository signing key
apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
apt update
# -y skips the interactive confirmation prompt
apt install -y nsight-systems-cli

4.0.2 Check the installation

!nsys status -e
Timestamp counter supported: Yes

CPU Profiling Environment Check
Root privilege: enabled
Linux Kernel Paranoid Level = 2
Linux Distribution = Ubuntu
Linux Kernel Version = 6.1.85+: OK
Linux perf_event_open syscall available: OK
Sampling trigger event available: OK
Intel(c) Last Branch Record support: Not Available
CPU Profiling Environment (process-tree): OK
CPU Profiling Environment (system-wide): OK

See the product documentation at https://docs.nvidia.com/nsight-systems for more information,
including information on how to set the Linux Kernel Paranoid Level.

4.0.3 Simple attention

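Here’s our basic attention mechanism. It projects the input into query, key, and value matrices, turns the query-key dot products into weights with a softmax, and returns the weighted sum of the values. For simplicity it uses a single head and omits the usual 1/√d scaling. The SimpleTransformer class combines this attention mechanism with layer normalization in a residual connection setup.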

We will include profiling code to measure CPU and GPU performance metrics when running the model on sample input data.

%%writefile profiler.py

import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = torch.softmax(attn_weights, dim=-1)

        return torch.matmul(attn_weights, v)

class SimpleTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = SimpleAttention(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.attention(x)
        return self.norm(x + attn_output)

# Create a model and sample input
embed_dim = 256
seq_length = 100
batch_size = 32

model = SimpleTransformer(embed_dim, num_heads=1).cuda()
sample_input = torch.randn(batch_size, seq_length, embed_dim).cuda()

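# torch.profiler (not torch.cuda.profiler) provides profile, record_function and ProfilerActivity
import torch.profiler as profiler

# Warm-up run, then wait for the queued GPU work to finish before profiling
model(sample_input)
torch.cuda.synchronize()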

# Profile the model
with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True) as prof:
    with profiler.record_function("model_inference"):
        model(sample_input)

# Print profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Writing profiler.py
!nsys profile --stats=true python profiler.py
Collecting data...
Traceback (most recent call last):
  File "/content/profiler.py", line 46, in <module>
    with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True) as prof:
AttributeError: module 'torch.cuda.profiler' has no attribute 'ProfilerActivity'
Generating '/tmp/nsys-report-4b28.qdstrm'
[1/8] [========================100%] report1.nsys-rep
[2/8] [========================100%] report1.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /content/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)    Max (ns)    StdDev (ns)            Name         
 --------  ---------------  ---------  ------------  ------------  ---------  -----------  ------------  ----------------------
     76.0    1,702,896,538         29  58,720,570.3  77,869,201.0      4,006  100,150,826  44,495,375.6  poll                  
     13.2      295,663,488      1,672     176,832.2       3,329.0      1,004   15,374,049     641,970.2  read                  
      4.0       89,772,667        611     146,927.4      16,308.0      1,509   23,241,632   1,163,272.8  ioctl                 
      3.9       86,647,314      5,802      14,934.0       3,006.5      1,006   13,686,743     246,663.1  stat64                
      1.7       38,487,406        992      38,797.8      13,551.0      2,049      844,849     101,781.4  open64                
      0.5       10,831,409      7,221       1,500.0       1,469.0      1,000       18,716         610.5  lstat64               
      0.3        6,090,781          1   6,090,781.0   6,090,781.0  6,090,781    6,090,781           0.0  nanosleep             
      0.2        3,767,157         76      49,567.9      10,444.5      3,410    2,163,110     247,130.7  mmap64                
      0.2        3,716,694      1,853       2,005.8       1,794.0      1,018       25,280       1,277.1  fstat64               
      0.0          596,671          4     149,167.8      49,473.0     35,192      462,533     209,267.0  sem_timedwait         
      0.0          521,407         74       7,046.0       3,520.5      1,697      169,539      19,585.7  fopen                 
      0.0          333,517         20      16,675.8       9,374.5      2,751      105,415      22,298.1  mmap                  
      0.0          303,794         16      18,987.1      12,165.5      1,402       57,334      18,133.4  write                 
      0.0          261,847          8      32,730.9      32,683.5     23,171       39,984       5,418.0  fgets                 
      0.0          207,125          3      69,041.7      68,447.0     66,655       72,023       2,733.0  sleep                 
      0.0          142,461         70       2,035.2       1,419.5      1,079       13,265       1,673.8  fclose                
      0.0          118,808          2      59,404.0      59,404.0     56,211       62,597       4,515.6  pthread_create        
      0.0          114,648          9      12,738.7      13,780.0      7,881       21,544       4,325.4  munmap                
      0.0           92,736         12       7,728.0       4,024.0      1,623       27,163       8,106.5  pthread_cond_signal   
      0.0           85,302         15       5,686.8       4,267.0      2,028       22,918       5,157.5  open                  
      0.0           33,190          5       6,638.0       5,099.0      1,635       16,932       5,952.6  fopen64               
      0.0           23,712          2      11,856.0      11,856.0      8,025       15,687       5,417.9  socket                
      0.0           10,493          1      10,493.0      10,493.0     10,493       10,493           0.0  connect               
      0.0            8,915          1       8,915.0       8,915.0      8,915        8,915           0.0  pthread_cond_broadcast
      0.0            7,372          1       7,372.0       7,372.0      7,372        7,372           0.0  pipe2                 
      0.0            6,383          1       6,383.0       6,383.0      6,383        6,383           0.0  getc                  
      0.0            5,305          1       5,305.0       5,305.0      5,305        5,305           0.0  fread                 
      0.0            5,009          2       2,504.5       2,504.5      1,829        3,180         955.3  sigaction             
      0.0            4,778          4       1,194.5       1,195.0      1,043        1,345         123.3  fflush                
      0.0            3,201          3       1,067.0       1,055.0      1,033        1,113          41.3  fcntl                 
      0.0            2,580          1       2,580.0       2,580.0      2,580        2,580           0.0  fputs_unlocked        
      0.0            2,300          1       2,300.0       2,300.0      2,300        2,300           0.0  bind                  
      0.0            1,609          1       1,609.0       1,609.0      1,609        1,609           0.0  fcntl64               
      0.0            1,505          1       1,505.0       1,505.0      1,505        1,505           0.0  listen                
      0.0            1,318          1       1,318.0       1,318.0      1,318        1,318           0.0  pthread_mutex_trylock 

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)    Max (ns)   StdDev (ns)               Name            
 --------  ---------------  ---------  ------------  ------------  ---------  ----------  ------------  ----------------------------
     59.0      133,942,988          8  16,742,873.5      44,480.0     18,304  75,424,540  29,365,709.9  cudaLaunchKernel            
     19.9       45,123,527          9   5,013,725.2      21,450.0      7,039  44,149,269  14,677,890.0  cudaMemcpyAsync             
     19.1       43,452,758          2  21,726,379.0  21,726,379.0  5,346,866  38,105,892  23,164,129.4  cudaFree                    
      0.8        1,898,877          6     316,479.5     216,731.0     12,598     976,777     341,108.4  cudaMalloc                  
      0.6        1,421,966         18      78,998.1         668.0        616   1,401,568     330,071.6  cudaEventCreateWithFlags    
      0.3          658,905          3     219,635.0       3,212.0      2,603     653,090     375,383.2  cudaStreamIsCapturing_v10000
      0.2          469,330      1,149         408.5         256.0        126     146,038       4,301.4  cuGetProcAddress_v2         
      0.1          182,122          9      20,235.8       7,191.0      5,845      68,490      20,841.1  cudaStreamSynchronize       
      0.0            7,048          3       2,349.3       2,324.0      2,234       2,490         129.9  cuInit                      
      0.0            1,790          4         447.5         300.5        252         937         327.1  cuModuleGetLoadingMode      

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     61.8          632,402          3  210,800.7  210,491.0   209,947   211,964      1,043.5  volta_sgemm_128x64_tn                                                                               
     14.3          146,205          1  146,205.0  146,205.0   146,205   146,205          0.0  volta_sgemm_64x64_tn                                                                                
     11.2          114,973          1  114,973.0  114,973.0   114,973   114,973          0.0  volta_sgemm_128x64_nn                                                                               
      8.1           82,590          1   82,590.0   82,590.0    82,590    82,590          0.0  void at::native::<unnamed>::vectorized_layer_norm_kernel<float, float>(int, T2, const T1 *, const T…
      3.1           32,127          1   32,127.0   32,127.0    32,127    32,127          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctor_add<float>, at::deta…
      1.4           14,240          1   14,240.0   14,240.0    14,240    14,240          0.0  void <unnamed>::softmax_warp_forward<float, float, float, (int)7, (bool)0, (bool)0>(T2 *, const T1 …

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)           Operation          
 --------  ---------------  -----  --------  --------  --------  --------  -----------  ----------------------------
    100.0          675,793      9  75,088.1     768.0       735   599,634    197,031.0  [CUDA memcpy Host-to-Device]

[8/8] Executing 'cuda_gpu_mem_size_sum' stats report

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation          
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
      4.068      9     0.452     0.001     0.001     3.277        1.067  [CUDA memcpy Host-to-Device]

Generated:
    /content/report1.nsys-rep
    /content/report1.sqlite

(Numbers will differ slightly each time we run these cells)
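The nvtx_sum step in this output was skipped because the script emits no NVTX ranges. The torch.cuda.profiler module is actually meant for this kind of Nsight workflow: its start() and stop() functions wrap cudaProfilerStart/cudaProfilerStop, which nsys can use as a capture range. The sketch below is a hedged illustration, not part of the original notebook: the helper names are hypothetical, a small nn.Linear stands in for the transformer, and the NVTX and capture calls are skipped when no GPU is present.

```python
import contextlib

import torch
import torch.cuda.nvtx
import torch.cuda.profiler
import torch.nn as nn


def nvtx_range(name):
    # NVTX ranges need a CUDA build and a GPU; fall back to a no-op context manager otherwise
    if torch.cuda.is_available():
        return torch.cuda.nvtx.range(name)
    return contextlib.nullcontext()


def run_profiled_inference(model, x, name="model_inference"):
    """One forward pass inside an nsys capture range and a named NVTX range (hypothetical helper)."""
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()     # keep earlier work (e.g. the warm-up) out of the capture
        torch.cuda.profiler.start()  # cudaProfilerStart: nsys starts capturing here
    with torch.no_grad(), nvtx_range(name):
        out = model(x)
    if use_cuda:
        torch.cuda.synchronize()     # let the kernels finish inside the capture range
        torch.cuda.profiler.stop()   # cudaProfilerStop: nsys stops capturing here
    return out


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(256, 256).to(device)  # stand-in for SimpleTransformer
    sample_input = torch.randn(32, 100, 256, device=device)
    model(sample_input)  # warm-up, outside the capture range
    print(run_profiled_inference(model, sample_input).shape)
```

Running a script like this with `nsys profile --capture-range=cudaProfilerApi --stats=true python profiler.py` should limit the report to the bracketed forward pass, and the nvtx_sum step should then list the model_inference range instead of being skipped.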

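The traceback at the top of this output shows that the script stopped with an AttributeError at the profiler.profile(...) line, because torch.cuda.profiler has no ProfilerActivity attribute. Nsight Systems traces the whole process regardless, so the reports above still cover everything that ran before the error: model creation and the warm-up forward pass. That is enough to see which kernels a single forward pass launches.

Let’s analyze the “cuda_gpu_kern_sum” report, which shows the GPU kernel executions: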

  • volta_sgemm_128x64_tn (61.8% of GPU time): This is one of NVIDIA’s optimized cuBLAS GEMM (General Matrix Multiplication) kernels. It ran three times with nearly identical durations (about 210 µs each), so it is most likely the three nn.Linear projections (query, key and value), which all multiply the same input by a 256×256 weight matrix. Matrix multiplications are typically the most compute-intensive operations in a transformer model.

  • volta_sgemm_64x64_tn (14.3% of GPU time): This is most likely one of the two batched matrix multiplications inside the attention computation, probably the attention scores (torch.matmul(q, k.transpose(-2, -1))).

  • volta_sgemm_128x64_nn (11.2% of GPU time): This is most likely the other batched matrix multiplication, of the attention weights with the value matrix (torch.matmul(attn_weights, v)).

  • vectorized_layer_norm_kernel (8.1% of GPU time): This corresponds to the LayerNorm operation in the SimpleTransformer class.

  • vectorized_elementwise_kernel with CUDAFunctor_add (3.1% of GPU time): This is most likely the element-wise addition in the residual connection (x + attn_output).

  • softmax_warp_forward (1.4% of GPU time): This is the softmax operation applied to the attention weights.
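A quick FLOP count supports this mapping. The sketch below uses the common 2·M·N·K estimate for a matrix product and the notebook’s sizes (batch_size = 32, seq_length = 100, embed_dim = 256); the helper name is hypothetical, and it assumes nn.Linear treats the (batch, seq) dimensions as one flattened set of rows.

```python
def gemm_flops(m, n, k, batch=1):
    # An (m x k) @ (k x n) product needs m*n*k multiply-adds, counted as 2 FLOPs each
    return 2 * batch * m * n * k


batch_size, seq_length, embed_dim = 32, 100, 256

# nn.Linear on (batch, seq, embed): (3200 x 256) @ (256 x 256), done three times (q, k, v)
projection = gemm_flops(batch_size * seq_length, embed_dim, embed_dim)
# q @ k^T for each batch element: (100 x 256) @ (256 x 100)
scores = gemm_flops(seq_length, seq_length, embed_dim, batch=batch_size)
# attn_weights @ v for each batch element: (100 x 100) @ (100 x 256)
weighted_values = gemm_flops(seq_length, embed_dim, seq_length, batch=batch_size)

print(projection, scores, weighted_values)  # 419430400 163840000 163840000
```

The three projections are identical in shape, which matches the three equal-duration instances of volta_sgemm_128x64_tn, while the two attention products are smaller and equal to each other in FLOPs.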

In total, the kernels launched by SimpleAttention (the three GEMM kernels and the softmax) account for about 88.7% of the GPU kernel execution time, which confirms that the attention mechanism dominates the computation. To optimize this, we could:

  • Use the optimized attention mechanism as suggested in the tutorial (torch.nn.functional.scaled_dot_product_attention).
  • Experiment with different batch sizes or sequence lengths to find the optimal configuration for your hardware.
  • Consider using mixed precision (float16), for example via torch.autocast; on the T4, float16 matrix multiplications can run on Tensor Cores.
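A minimal sketch of the mixed-precision suggestion, assuming inference only: torch.autocast runs eligible operations such as linear layers and matrix multiplications in a lower-precision type. The helper name is hypothetical; it picks float16 on a GPU and bfloat16 on the CPU, where bfloat16 is autocast’s default lower-precision type.

```python
import torch
import torch.nn as nn


def run_mixed_precision(model, x):
    """Forward pass under autocast (hypothetical inference helper)."""
    device_type = x.device.type  # "cuda" or "cpu"
    # float16 on the GPU (Tensor Cores on the T4); bfloat16 on the CPU
    dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype):
        return model(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.Linear(256, 256).to(device)
x = torch.randn(32, 100, 256, device=device)
print(run_mixed_precision(layer, x).dtype)  # the lower-precision type chosen above
```

Autocast keeps some precision-sensitive operations, such as LayerNorm, in float32, so a full SimpleTransformer may still return float32 outputs even though its matrix multiplications ran in lower precision.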

4.1 Flash Attention

Transformer models can be bottlenecked by self-attention, whose time and memory requirements grow quadratically with sequence length.
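To make “quadratic” concrete for the two configurations used in this notebook: standard attention materializes a seq_length × seq_length weight matrix per batch element and head. A small float32 estimate (the helper name is hypothetical):

```python
def attention_matrix_bytes(batch_size, seq_length, num_heads=1, bytes_per_element=4):
    """Size of one materialized (batch, heads, seq, seq) attention-weight tensor."""
    return batch_size * num_heads * seq_length * seq_length * bytes_per_element


for seq_length in (100, 1000):
    mb = attention_matrix_bytes(32, seq_length) / 1e6
    print(f"seq_length={seq_length}: {mb:.2f} MB")
# seq_length=100: 1.28 MB
# seq_length=1000: 128.00 MB
```

Sequences that are 10× longer produce a 100× larger matrix, and standard attention writes it to GPU memory and reads it back at least once for the softmax.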

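  • Standard attention keeps queries, keys, values and the intermediate attention matrices in the GPU’s main memory: High Bandwidth Memory (HBM) on data-center GPUs such as the A100 (the T4 uses GDDR6 instead)

  • This memory is large, but much slower to access than the GPU’s small on-chip SRAM.

  • Standard attention typically follows the pattern:

    • Load queries, keys and values from HBM into on-chip SRAM
    • Compute a single step of attention (for example the scores, then the softmax) and write the result back to HBM
    • Repeat for each step, reading the previous result back from HBM
  • We can optimize the model by using FlashAttention

  • FlashAttention does its work in SRAM, which is much smaller but much faster. This method:

    • Loads queries, keys and values into SRAM in blocks (tiles) small enough to fit there
    • Fuses the operations of the attention mechanism into one kernel, computing the softmax incrementally block by block, so the full sequence-by-sequence attention matrix is never written to HBM
    • Writes only the final output back to HBM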
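The trick that lets FlashAttention process keys and values block by block is an online softmax: keep a running maximum and a running sum, and rescale the partial result whenever the maximum changes. The pure-Python sketch below shows this for a single query vector. It is a simplified illustration of the math only (no GPU, no SRAM management), it omits the 1/√d scaling just as SimpleAttention does, and the function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def standard_attention(q, keys, values):
    """Materialize all scores, softmax them, then take the weighted sum of the values."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Process keys/values one block at a time with a running max and sum (online softmax)."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized weighted sum of the values seen so far
    for start in range(0, len(keys), block_size):
        k_block = keys[start:start + block_size]
        v_block = values[start:start + block_size]
        scores = [dot(q, k) for k in k_block]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old max (exp(-inf) == 0.0 on the first block)
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_block):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(standard_attention(q, keys, values))
print(tiled_attention(q, keys, values, block_size=2))
```

Both functions return the same vector (up to floating-point rounding) for any block size, yet tiled_attention never holds more than one block of scores at a time, which is what allows the real kernel to keep its working set in SRAM.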

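Let’s use torch.nn.functional.scaled_dot_product_attention, PyTorch’s fused attention function. It dispatches to a FlashAttention kernel when the inputs and hardware support it, and otherwise falls back to other implementations. For more on this mechanism, see the research paper (Dao et al., 2022, “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”).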

%%writefile profiler.py

import torch
import torch.nn as nn

import torch.nn.functional as F

class OptimizedAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** -0.5

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        return F.scaled_dot_product_attention(q, k, v, scale=self.scale)

# Same as SimpleTransformer, but using OptimizedAttention
class OptimizedTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = OptimizedAttention(embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output = self.attention(x)
        return self.norm(x + attn_output)

# Create a model and sample input
embed_dim = 256
seq_length = 1000  # 10x the sequence length used for SimpleTransformer above
batch_size = 32

# Create a new model with the optimized attention
optimized_model = OptimizedTransformer(embed_dim, num_heads=1).cuda()
sample_input = torch.randn(batch_size, seq_length, embed_dim).cuda()

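# torch.profiler (not torch.cuda.profiler) provides profile, record_function and ProfilerActivity
import torch.profiler as profiler

# Warm-up run, then wait for the queued GPU work to finish before profiling
optimized_model(sample_input)
torch.cuda.synchronize()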

# Profile the optimized model
with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True) as prof:
    with profiler.record_function("optimized_model_inference"):
        optimized_model(sample_input)

# Print profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
Overwriting profiler.py
!nsys profile --stats=true python profiler.py
Collecting data...
Traceback (most recent call last):
  File "/content/profiler.py", line 48, in <module>
    with profiler.profile(activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True) as prof:
AttributeError: module 'torch.cuda.profiler' has no attribute 'ProfilerActivity'
Generating '/tmp/nsys-report-ed31.qdstrm'
[1/8] [========================100%] report2.nsys-rep
[2/8] [========================100%] report2.sqlite
[3/8] Executing 'nvtx_sum' stats report
SKIPPED: /content/report2.sqlite does not contain NV Tools Extension (NVTX) data.
[4/8] Executing 'osrt_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)     Med (ns)    Min (ns)    Max (ns)    StdDev (ns)           Name         
 --------  ---------------  ---------  ------------  -----------  ---------  -----------  ------------  ---------------------
     79.1      501,166,791         17  29,480,399.5  2,985,176.0      3,734  100,157,864  38,521,524.2  poll                 
     11.6       73,748,312        639     115,412.1     12,665.0      1,374   18,555,420   1,004,303.4  ioctl                
      2.9       18,137,686      5,796       3,129.3      2,490.0      1,000       47,046       2,214.1  stat64               
      1.6        9,881,832      6,366       1,552.3      1,470.0      1,000       21,190         685.9  lstat64              
      1.3        8,400,908      1,104       7,609.5      2,719.5      1,007      251,640      16,190.3  read                 
      1.2        7,379,916        992       7,439.4      6,844.5      1,883       28,734       2,950.8  open64               
      0.8        5,069,938          1   5,069,938.0  5,069,938.0  5,069,938    5,069,938           0.0  nanosleep            
      0.6        3,616,562         76      47,586.3     10,088.5      4,364    2,222,049     254,372.7  mmap64               
      0.5        3,406,279      1,852       1,839.2      1,636.5      1,070       49,459       1,738.7  fstat64              
      0.1          906,478          9     100,719.8     67,874.0     41,581      396,079     111,299.9  sem_timedwait        
      0.1          354,877         74       4,795.6      3,567.5      1,586       25,689       4,162.3  fopen                
      0.0          287,427         27      10,645.4      7,683.0      2,462       61,914      11,272.9  mmap                 
      0.0          252,649          8      31,581.1     32,863.0     12,821       39,628       8,314.4  fgets                
      0.0          221,124          3      73,708.0     71,618.0     68,171       81,335       6,826.3  sleep                
      0.0          157,808         16       9,863.0      8,608.0      4,446       27,809       5,358.3  munmap               
      0.0          141,700         69       2,053.6      1,431.0      1,007       14,453       1,763.4  fclose               
      0.0          123,084         16       7,692.8      5,635.5      1,063       22,553       6,244.1  write                
      0.0          113,446          2      56,723.0     56,723.0     52,791       60,655       5,560.7  pthread_create       
      0.0           80,095         15       5,339.7      4,115.0      1,789       15,904       3,756.7  open                 
      0.0           29,997          5       5,999.4      3,918.0      1,410       16,372       5,916.5  fopen64              
      0.0           25,575          2      12,787.5     12,787.5      9,492       16,083       4,660.5  socket               
      0.0           13,661          5       2,732.2      1,858.0      1,045        6,703       2,274.3  pthread_cond_signal  
      0.0           12,962          1      12,962.0     12,962.0     12,962       12,962           0.0  connect              
      0.0           12,710          2       6,355.0      6,355.0      1,101       11,609       7,430.3  pthread_mutex_trylock
      0.0           10,017          1      10,017.0     10,017.0     10,017       10,017           0.0  pipe2                
      0.0            6,291          1       6,291.0      6,291.0      6,291        6,291           0.0  getc                 
      0.0            4,856          1       4,856.0      4,856.0      4,856        4,856           0.0  fread                
      0.0            3,791          3       1,263.7      1,188.0      1,062        1,541         248.3  fcntl                
      0.0            3,692          3       1,230.7      1,150.0      1,078        1,464         205.3  fflush               
      0.0            3,620          2       1,810.0      1,810.0      1,524        2,096         404.5  sigaction            
      0.0            2,106          1       2,106.0      2,106.0      2,106        2,106           0.0  fputs_unlocked       
      0.0            1,780          1       1,780.0      1,780.0      1,780        1,780           0.0  bind                 
      0.0            1,637          1       1,637.0      1,637.0      1,637        1,637           0.0  fcntl64              
      0.0            1,376          1       1,376.0      1,376.0      1,376        1,376           0.0  listen               

[5/8] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls   Avg (ns)     Med (ns)    Min (ns)    Max (ns)   StdDev (ns)               Name            
 --------  ---------------  ---------  -----------  -----------  ---------  ----------  ------------  ----------------------------
     64.9       73,600,176         10  7,360,017.6     25,209.5     13,686  29,189,994  12,030,081.2  cudaLaunchKernel            
     19.6       22,201,745          9  2,466,860.6     11,600.0      3,985  15,341,866   5,310,374.7  cudaMemcpyAsync             
     13.2       15,018,628          2  7,509,314.0  7,509,314.0  2,117,769  12,900,859   7,624,796.1  cudaFree                    
      2.0        2,217,914         13    170,608.8    172,874.0      4,544     329,658      81,216.7  cudaMalloc                  
      0.2          222,195      1,149        193.4        165.0         89       1,779         108.2  cuGetProcAddress_v2         
      0.2          173,034          9     19,226.0      6,419.0      5,973      55,838      17,545.3  cudaStreamSynchronize       
      0.0           25,834         18      1,435.2        371.0        354      15,734       3,605.2  cudaEventCreateWithFlags    
      0.0           24,330         10      2,433.0      2,015.5      1,109       6,827       1,643.2  cudaStreamIsCapturing_v10000
      0.0            4,335          3      1,445.0      1,419.0      1,209       1,707         250.0  cuInit                      
      0.0            2,221          4        555.3        217.5        150       1,636         721.3  cuModuleGetLoadingMode      

[6/8] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     57.5       11,818,611          4  2,954,652.8  1,712,425.5  1,701,017  6,692,743  2,492,068.0  volta_sgemm_128x64_tn                                                                               
     28.4        5,834,459          1  5,834,459.0  5,834,459.0  5,834,459  5,834,459          0.0  volta_sgemm_128x64_nn                                                                               
      5.8        1,197,893          1  1,197,893.0  1,197,893.0  1,197,893  1,197,893          0.0  void <unnamed>::softmax_warp_forward<float, float, float, (int)10, (bool)0, (bool)0>(T2 *, const T1…
      3.7          763,694          1    763,694.0    763,694.0    763,694    763,694          0.0  void at::native::<unnamed>::vectorized_layer_norm_kernel<float, float>(int, T2, const T1 *, const T…
      2.6          535,827          2    267,913.5    267,913.5    263,226    272,601      6,629.1  void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
      1.9          389,943          1    389,943.0    389,943.0    389,943    389,943          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctor_add<float>, at::deta…

[7/8] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count  Avg (ns)   Med (ns)  Min (ns)  Max (ns)   StdDev (ns)           Operation          
 --------  ---------------  -----  ---------  --------  --------  ---------  -----------  ----------------------------
    100.0        6,588,201      9  732,022.3     736.0       704  6,510,028  2,166,783.8  [CUDA memcpy Host-to-Device]

[8/8] Executing 'cuda_gpu_mem_size_sum' stats report

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation          
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
     33.560      9     3.729     0.001     0.001    32.768       10.890  [CUDA memcpy Host-to-Device]

Generated:
    /content/report2.nsys-rep
    /content/report2.sqlite

(Numbers will differ slightly each time we run these cells)

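Looking at the “cuda_gpu_kern_sum” report, keep in mind that this run uses seq_length = 1000 instead of 100. Total kernel time is therefore about 20.5 ms instead of about 1.0 ms, and the percentages are not a like-for-like comparison with the first run (for that, run both models with the same seq_length). We notice:

  • volta_sgemm_128x64_tn (57.5% of GPU time, previously 61.8%):
    • It now runs four times instead of three. Three instances take about 1.7 ms each, most likely the query, key and value projections; the fourth takes about 6.7 ms and is most likely the attention scores q · kᵀ, which previously used volta_sgemm_64x64_tn.
  • volta_sgemm_128x64_nn (28.4%, previously 11.2%):
    • Most likely the multiplication of the attention weights with the value matrix. Like the attention scores, its cost grows quadratically with sequence length, which is why its share jumped.
  • softmax_warp_forward (5.8%, previously 1.4%):
    • The softmax over the attention scores, now applied to 1000 × 1000 matrices instead of 100 × 100.
  • vectorized_layer_norm_kernel (3.7%, previously 8.1%):
    • The LayerNorm operation, whose cost grows only linearly with sequence length, so its share shrank.
  • vectorized_elementwise_kernel with AUnaryFunctor (2.6%, two instances, new):
    • Most likely the scale factor being applied to q and k as separate element-wise multiplications.
  • vectorized_elementwise_kernel with CUDAFunctor_add (1.9%, previously 3.1%):
    • The residual addition (x + attn_output).

Notably, no fused attention kernel appears: attention is still computed as a GEMM, a softmax and a second GEMM, the same pattern as SimpleAttention. This suggests that scaled_dot_product_attention fell back to its unfused math implementation here. PyTorch’s FlashAttention kernel requires float16 or bfloat16 inputs, so float32 tensors like these cannot use it; the fused kernels also generally expect 4-D (batch, heads, sequence, head_dim) inputs, and recent PyTorch releases enable FlashAttention only on Ampere or newer GPUs, not on the T4’s Turing architecture.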