
Computer Architecture Labs

Harbin Institute of Technology (Shenzhen) • 2024 • Computer Architecture Lab • Solutions • HITSZ 计算机体系结构实验 2024

If you have any questions, please use the comment section at the bottom of this page.

千秋

The following labs require knowledge of x86 assembly and CUDA programming.



Lab1 Matrix Multiplication

Environment Setup and Preliminary Test

The goal of this task is to print a 64-bit integer to the console.

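First, let's look at the given code. There are two subtasks.

The first subtask is to make sure every digit gets printed. The code extracts the digits by repeated division. Dividing by 10 leaves the remainder in rdx, and that remainder is the next digit to print, starting from the least significant one. The quotient stays in rax for the next iteration.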

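The loop should end once rax becomes 0 after the last division. On x86, jnz decides whether to jump by checking the Z flag. We therefore need an instruction that sets this flag from rax without modifying rax, since changing rax would break the loop. The instruction $$\textbf{test}\text{ \%rax, \%rax}$$ does exactly this. It ANDs rax with itself, sets ZF if the result is 0, and discards the result, so rax itself is left unchanged.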
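In C, convert_loop together with the test/jnz exit condition corresponds to a do-while loop that fills the buffer from the end. This is an illustrative sketch, not the lab's code, and the name u64_to_dec is hypothetical. The comments map each C statement to the assembly instruction it mirrors.

```c
#include <stdint.h>
#include <string.h>

/* Mirrors convert_loop: one division by 10 per digit, buffer filled backwards.
 * buf must hold 21 bytes (20 digits + '\0'), like ".lcomm buffer, 21".
 * Returns a pointer to the first digit (the final value of %rdi). */
static char *u64_to_dec(uint64_t value, char buf[21])
{
    char *p = buf + 20;                 /* lea buffer+20(%rip), %rdi */
    *p = '\0';                          /* movb $0, (%rdi) */
    do {
        uint64_t digit = value % 10;    /* remainder -> %rdx */
        value /= 10;                    /* quotient  -> %rax */
        *--p = (char)('0' + digit);     /* add $'0', %dl; dec %rdi; mov %dl, (%rdi) */
    } while (value != 0);               /* test %rax, %rax; jnz convert_loop */
    return p;
}
```

For example, with `char buf[21]`, the check `strcmp(u64_to_dec(1234567890123456789ULL, buf), "1234567890123456789") == 0` holds. Because the condition is checked after the body, the value 0 still yields the single digit "0".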

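The second subtask is to fix the system call. After going through the code carefully, we found two kinds of problems: some logic was unnecessary, and some registers were set incorrectly. For example, `mov %rdx, %rdx` is a no-op. More seriously, the string pointer was copied from rdi into rsi only after rdi had already been overwritten with the file descriptor 1.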
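For comparison, the corrected register setup is just an ordinary write system call. The sketch below issues it through glibc's syscall() wrapper (Linux only), which places the syscall number in rax and the arguments in rdi, rsi and rdx, in that order. The function name write_stdout is hypothetical.

```c
#define _GNU_SOURCE            /* exposes syscall() in <unistd.h> */
#include <string.h>
#include <sys/syscall.h>       /* SYS_write */
#include <unistd.h>            /* syscall() */

/* Same arguments as the corrected assembly:
 *   %rax = SYS_write, %rdi = 1 (stdout), %rsi = buf, %rdx = len.
 * Returns the number of bytes written, or -1 on error. */
static long write_stdout(const char *buf, size_t len)
{
    return syscall(SYS_write, 1, buf, len);
}
```

A call such as `write_stdout(msg, strlen(msg))` should return `strlen(msg)` when stdout is writable.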

/**
 * Print a 64-bit integer to standard output
 */

.section .bss
//    .lcomm num, 8       // storage for the 64-bit integer (unused)
    .lcomm buffer, 21   // output buffer: 20 digits + 1 null terminator

.section .data
    newline: .byte 0xA      // newline character

.section .text
    .globl _start

_start:
    // initialize the number to print
    mov $1234567890123456789, %rax
//    mov %rax, num(%rip)

    // convert the integer to a string
//    mov num(%rip), %rax
    lea buffer+20(%rip), %rdi   // rdi = address of the last byte of the output string
    movb $0, (%rdi)             // store '\0' there to mark the end

convert_loop:                   // convert the integer into a digit string, last digit first
//    mov %rax, %rdx
    xor %rdx, %rdx              // div divides rdx:rax, so clear the high half first
    mov $10, %rcx
    div %rcx                    // rdx = rax % 10, rax = rax / 10
    add $'0', %dl               // convert the digit to ASCII (dl is the low 8 bits of rdx)
    dec %rdi
    mov %dl, (%rdi)             // store the digit
    test %rax, %rax             // termination check: sets ZF when rax == 0, rax is unchanged
    jnz convert_loop            // quotient not yet 0: keep looping

find_start:                     // skip leading '0' characters (none are produced for a non-zero value)
    cmpb $'0', (%rdi)
    jne print_string
    inc %rdi
    jmp find_start

print_string:                   // print the string
                                // compute its length
    lea buffer+20(%rip), %rax
    sub %rdi, %rax              // number of bytes stored
    mov %rax, %rdx              // rdx = number of bytes to write

    // sys_write(fd = 1, buf = rsi, count = rdx)
    mov %rdi, %rsi              // string pointer: copy it BEFORE rdi is overwritten
    mov $1, %rax                // syscall number (sys_write)
    mov $1, %rdi                // file descriptor (stdout)
    syscall                     // rdx already holds the byte count

    // print a newline
    mov $1, %rax
    mov $1, %rdi
    lea newline(%rip), %rsi
    mov $1, %rdx
    syscall

    // exit the program
    mov $60, %rax  // syscall number (sys_exit)
    xor %rdi, %rdi // exit status 0
    syscall

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ ./dist/bins/lab1_print_integer   
1234567890123456789

Completing the Matrix Multiplication Code

A quick look at the code shows that all we need to do is compute the addresses of A[$m$][$k$], B[$k$][$n$] and C[$m$][$n$] correctly. Physically, the arrays are stored linearly in row-major order. If the sizes of the $m, n, k$ dimensions are DIM_M, DIM_N and DIM_K, the linear addresses are:

\begin{align*} \text{A}[m][k]&: \text{base\_A} + 4\times(m \times\text{DIM\_K} + k)\\ \text{B}[k][n]&: \text{base\_B} + 4\times(k \times\text{DIM\_N} + n)\\ \text{C}[m][n]&: \text{base\_C} + 4\times(m \times\text{DIM\_N} + n) \end{align*}

Each single-precision float occupies 4 bytes, hence the factor of 4.
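The three formulas translate directly into C. This is an illustrative sketch: the function names are hypothetical, and dim_k and dim_n stand for DIM_K and DIM_N.

```c
#include <stddef.h>

_Static_assert(sizeof(float) == 4, "the formulas assume 4-byte single-precision floats");

/* Byte offsets from base_A, base_B and base_C for row-major float matrices. */
static size_t offset_A(size_t m, size_t k, size_t dim_k) { return 4 * (m * dim_k + k); }
static size_t offset_B(size_t k, size_t n, size_t dim_n) { return 4 * (k * dim_n + n); }
static size_t offset_C(size_t m, size_t n, size_t dim_n) { return 4 * (m * dim_n + n); }
```

For example, offset_B(1, 2, 4) is 4 × (1 × 4 + 2) = 24. For a `float A[3][5]`, offset_A(2, 3, 5) equals the distance in bytes from `&A[0][0]` to `&A[2][3]`.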

On x86, $m, n, k$ are read from the loop counters. The base addresses of the arrays, as well as DIM_M, DIM_N and DIM_K, are provided as macros.

Taking A[$m$][$k$] as an example, the index computation and the load look like this:

\begin{align*} &\textbf{MOV} &\text{loop\_m, mat\_elem\_idx}\\ &\textbf{IMUL} &\text{DIM\_K, mat\_elem\_idx}\\ &\textbf{ADD} &\text{loop\_k, mat\_elem\_idx}\\ &\textbf{flds} &\text{(MAT\_A, mat\_elem\_idx, 4)} \end{align*}

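In fact, all we need is the element's linear index. The scaled-index addressing mode (MAT_A, mat_elem_idx, 4) used by flds multiplies the index by 4 and adds the base address MAT_A, so computing $m \times\text{DIM\_K} + k$ is enough. The other elements follow the same pattern.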
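For reference, DO_GEMM corresponds to the C loop nest below, with k outermost, then m, then n. This is a sketch for checking results, not the lab's own code, and the name gemm_kij is hypothetical. The argument order follows the registers the kernel reads: C in rdi, A in rsi, B in rdx, and then M, N, K in rcx, r8, r9.

```c
#include <assert.h>
#include <stddef.h>

/* C is M x N, A is M x K, B is K x N, all row-major.
 * Like the assembly, it ADDS into C, so C must be initialized beforehand. */
static void gemm_kij(float *C, const float *A, const float *B,
                     size_t M, size_t N, size_t K)
{
    assert(C && A && B);
    for (size_t k = 0; k < K; k++)
        for (size_t m = 0; m < M; m++) {
            float a_mk = A[m * K + k];                /* kept in st(0) across the n loop */
            for (size_t n = 0; n < N; n++)
                C[m * N + n] += a_mk * B[k * N + n];  /* B at k*N+n, C at m*N+n */
        }
}
```

Consider a 2×3 matrix A = {1, 2, 3, 4, 5, 6}, a 3×2 matrix B = {7, 8, 9, 10, 11, 12}, and a zeroed C. Calling gemm_kij(C, A, B, 2, 2, 3) gives C = {58, 64, 139, 154}.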

.text;
.p2align 2;
.global gemm_kernel;
.type gemm_kernel, %function;

// Macro definitions
#define     MAT_C               %rdi    // address of matrix C
#define     MAT_A               %rsi    // address of matrix A
#define     MAT_B               %r14    // address of matrix B
#define     DIM_M               %rcx    // number of rows of C (M)
#define     DIM_N               %r8     // number of columns of C (N)
#define     DIM_K               %r9     // shared dimension (K)
#define     loop_m              %r10    // loop counter along M
#define     loop_k              %r11    // loop counter along K
#define     loop_n              %r12    // loop counter along N
#define     mat_elem_idx        %r13    // linear index of the current matrix element

.macro PUSHD                                        // save register values
    push %rax
    push %rbx
    push %rcx
    push %rdx
    push %rsi
    push %rdi
    push %rbp
    push %r8
    push %r9
    push %r10
    push %r11
    push %r12
    push %r13
    push %r14
    push %r15
.endm

.macro POPD                                        // restore register values
    pop %r15
    pop %r14
    pop %r13
    pop %r12
    pop %r11
    pop %r10
    pop %r9
    pop %r8
    pop %rbp
    pop %rdi
    pop %rsi
    pop %rdx
    pop %rcx
    pop %rbx
    pop %rax
.endm

.macro GEMM_INIT                                   // initialization
    mov %rdx, MAT_B                                 // B arrives in rdx (3rd argument); keep it in MAT_B

    xor loop_m, loop_m                              // clear the M loop counter
    xor loop_k, loop_k                              // clear the K loop counter
    xor loop_n, loop_n                              // clear the N loop counter
.endm

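.macro DO_GEMM                                      // matrix product in k-i-j (k, m, n) order
DO_LOOP_K:                                          // outermost loop over K
    xor loop_m, loop_m                              // clear the M loop counter

DO_LOOP_M:                                          // loop over M
    xor loop_n, loop_n                              // clear the N loop counter

    // load A[m][k]: mat_elem_idx = m * K + k
    mov loop_m, mat_elem_idx
    imul DIM_K, mat_elem_idx
    add loop_k, mat_elem_idx
    flds (MAT_A, mat_elem_idx, 4)                   // load A[m][k] into st(0); flds can only load to the stack top st(0).
                                                    // The old st(0) moves to st(1); if the stack is full, the push fails.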

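DO_LOOP_N:
    // load B[k][n]: mat_elem_idx = k * N + n
    mov loop_k, mat_elem_idx
    imul DIM_N, mat_elem_idx
    add loop_n, mat_elem_idx
    flds (MAT_B, mat_elem_idx, 4)                   // load B[k][n]

    fmul %st(1), %st(0)                             // compute A[m][k] * B[k][n]

    // load C[m][n]: mat_elem_idx = m * N + n
    mov loop_m, mat_elem_idx
    imul DIM_N, mat_elem_idx
    add loop_n, mat_elem_idx
    flds (MAT_C, mat_elem_idx, 4)                   // load C[m][n]

    faddp %st(1), %st(0)                            // compute C[m][n] + A[m][k] * B[k][n]
    fstps (MAT_C, mat_elem_idx, 4)                  // write the result back to C[m][n]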

    add $1, loop_n                                  // increment the N loop counter
    cmp DIM_N, loop_n
    jl DO_LOOP_N

    fstp %st(0)                                     // pop st(0): this element of A is no longer needed
    add $1, loop_m                                  // increment the M loop counter
    cmp DIM_M, loop_m
    jl DO_LOOP_M

    add $1, loop_k                                  // increment the K loop counter
    cmp DIM_K, loop_k
    jl DO_LOOP_K
.endm

gemm_kernel:
    PUSHD                                           // save register values
    GEMM_INIT                                       // initialize
    DO_GEMM                                         // compute the matrix product
    POPD                                            // restore register values
    ret                                             // return from the function

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ ./dist/bins/lab1_test_gemm_kernel.unittest --gtest_filter=gemm_kernel.test0
Running main() from /home/amamitsu/Applications/lab11/build/_deps/googletest-src/googletest/src/gtest_main.cc
Note: Google Test filter = gemm_kernel.test0
[==========] Running 1 test from 1 test suite.                                                                                                                                         
[----------] Global test environment set-up.
[----------] 1 test from gemm_kernel
[ RUN      ] gemm_kernel.test0
[       OK ] gemm_kernel.test0 (46 ms)
[----------] 1 test from gemm_kernel (46 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (46 ms total)
[  PASSED  ] 1 test.
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ ./dist/bins/lab1_gemm 256 256 256 
GEMM performance info:
                M, K, N: 256, 256, 256
                Ops: 0.0335544
                Total compute time(s): 1.81176
                Cost(s): 0.00905878
                Benchmark(Gflops): 3.70408
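The printed figures can be reproduced from M, K, N and the total time. The sketch below rests on two assumptions inferred only from the printed values, not from the lab's source. First, Ops is the usual 2 × M × N × K floating-point operations in units of 10^9. Second, Cost is the total time divided by 200 repetitions, since 1.81176 / 0.00905878 ≈ 200. The struct and function names are hypothetical.

```c
#include <assert.h>

struct gemm_report { double gops, cost_s, gflops; };

/* ASSUMPTION: Ops = 2*M*N*K / 1e9 and Cost = total time / repeats. */
static struct gemm_report gemm_metrics(double M, double K, double N,
                                       double total_s, int repeats)
{
    struct gemm_report r;
    assert(repeats > 0 && total_s > 0);
    r.gops   = 2.0 * M * N * K / 1e9;   /* one multiply + one add per (m, n, k) */
    r.cost_s = total_s / repeats;       /* seconds per kernel call */
    r.gflops = r.gops / r.cost_s;
    return r;
}
```

With gemm_metrics(256, 256, 256, 1.81176, 200), this gives Ops ≈ 0.0335544, Cost ≈ 0.0090588 s and about 3.704 Gflops, matching the output above.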

CPU Info

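Task: print the basic CPU information, and also the basic parameters of the L1 data cache, the L2 cache and the L3 cache.
For this part it is enough to inspect the CPU information, which the lscpu command reports.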

For basic information about the L1D, L2 and L3 caches of CPU 0, change to this directory: /sys/devices/system/cpu/cpu0/cache

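This directory contains several subdirectories that hold the required information:

index0 is the L1D cache, index1 the L1I cache, index2 the L2 cache and index3 the L3 cache. The level and type files in each directory confirm which cache it describes.

The parameters we need are coherency_line_size, number_of_sets and ways_of_associativity.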

┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ lscpu
Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          45 bits physical, 48 bits virtual
  Byte Order:             Little Endian
CPU(s):                   16
  On-line CPU(s) list:    0-15
Vendor ID:                GenuineIntel
  Model name:             12th Gen Intel(R) Core(TM) i9-12900K
    CPU family:           6
    Model:                151
    Thread(s) per core:   1
    Core(s) per socket:   1
    Socket(s):            16
    Stepping:             2
    BogoMIPS:             6374.40
    Flags:                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid
                           tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erm
                          s invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni arat umip gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization features:  
  Hypervisor vendor:      VMware
  Virtualization type:    full
Caches (sum of all):      
  L1d:                    768 KiB (16 instances)
  L1i:                    512 KiB (16 instances)
  L2:                     20 MiB (16 instances)
  L3:                     480 MiB (16 instances)
NUMA:                     
  NUMA node(s):           1
  NUMA node0 CPU(s):      0-15
Vulnerabilities:          
  Gather data sampling:   Not affected
  Itlb multihit:          Not affected
  L1tf:                   Mitigation; PTE Inversion
  Mds:                    Mitigation; Clear CPU buffers; SMT Host state unknown
  Meltdown:               Mitigation; PTI
  Mmio stale data:        Not affected
  Reg file data sampling: Vulnerable: No microcode
  Retbleed:               Mitigation; IBRS
  Spec rstack overflow:   Not affected
  Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:             Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
  Srbds:                  Not affected
  Tsx async abort:        Not affected
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ cd /sys/devices/system/cpu/cpu0/cache      
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/system/cpu/cpu0/cache]
└─$ cd index0                            
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index0]
└─$ cat coherency_line_size  
64
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index0]
└─$ cat number_of_sets     
64
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index0]
└─$ cat ways_of_associativity
12
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index0]
└─$ cd ../index2             
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index2]
└─$ cat coherency_line_size  
64
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index2]
└─$ cat number_of_sets       
2048
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index2]
└─$ cat ways_of_associativity
10
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index2]
└─$ cd ../index3             
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index3]
└─$ cat coherency_line_size  
64
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index3]
└─$ cat number_of_sets       
40960
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index3]
└─$ cat ways_of_associativity
12
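Multiplying the three values gives the capacity of each cache instance. The helper below just does that arithmetic, and its name is hypothetical. The results agree with lscpu's totals divided by its 16 instances: 768 KiB / 16 = 48 KiB, 20 MiB / 16 = 1.25 MiB and 480 MiB / 16 = 30 MiB.

```c
/* capacity = coherency_line_size * number_of_sets * ways_of_associativity */
static unsigned long cache_bytes(unsigned long line_size,
                                 unsigned long sets,
                                 unsigned long ways)
{
    return line_size * sets * ways;
}
```

Here, cache_bytes(64, 64, 12) is 49152 bytes (48 KiB, L1D). cache_bytes(64, 2048, 10) is 1310720 bytes (1.25 MiB, L2). cache_bytes(64, 40960, 12) is 31457280 bytes (30 MiB, L3).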

Using perf

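First, the perf list command shows the predefined perf events. perf can then be used to examine the performance of the matrix multiplication, as shown below.

The first run (plain perf stat) reports general counters such as cycles, instructions, branches and branch misses. Passing specific events with -e reports only those events.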

┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index3]
└─$ perf list

List of pre-defined events (to be used in -e or -M):

  duration_time                                      [Tool event]
  user_time                                          [Tool event]
  system_time                                        [Tool event]
  mem-loads OR cpu_atom/mem-loads/                   [Kernel PMU event]
  mem-stores OR cpu_atom/mem-stores/                 [Kernel PMU event]
  ref-cycles OR cpu_atom/ref-cycles/                 [Kernel PMU event]
  topdown-bad-spec OR cpu_atom/topdown-bad-spec/     [Kernel PMU event]
  topdown-be-bound OR cpu_atom/topdown-be-bound/     [Kernel PMU event]
  topdown-fe-bound OR cpu_atom/topdown-fe-bound/     [Kernel PMU event]
  topdown-retiring OR cpu_atom/topdown-retiring/     [Kernel PMU event]
  mem-loads OR cpu_core/mem-loads/                   [Kernel PMU event]
  mem-loads-aux OR cpu_core/mem-loads-aux/           [Kernel PMU event]
  mem-stores OR cpu_core/mem-stores/                 [Kernel PMU event]
  ref-cycles OR cpu_core/ref-cycles/                 [Kernel PMU event]
  slots OR cpu_core/slots/                           [Kernel PMU event]
  topdown-bad-spec OR cpu_core/topdown-bad-spec/     [Kernel PMU event]
  topdown-be-bound OR cpu_core/topdown-be-bound/     [Kernel PMU event]
  topdown-br-mispredict OR cpu_core/topdown-br-mispredict/[Kernel PMU event]
  topdown-fe-bound OR cpu_core/topdown-fe-bound/     [Kernel PMU event]
  topdown-fetch-lat OR cpu_core/topdown-fetch-lat/   [Kernel PMU event]
  topdown-heavy-ops OR cpu_core/topdown-heavy-ops/   [Kernel PMU event]
  topdown-mem-bound OR cpu_core/topdown-mem-bound/   [Kernel PMU event]
  topdown-retiring OR cpu_core/topdown-retiring/     [Kernel PMU event]
  msr/pperf/                                         [Kernel PMU event]
  msr/smi/                                           [Kernel PMU event]
  msr/tsc/                                           [Kernel PMU event]

cache:
  longest_lat_cache.miss
       [Counts the number of cacheable memory requests that miss in the LLC. Counts on a per core basis. Unit: cpu_atom]
  longest_lat_cache.reference
       [Counts the number of cacheable memory requests that access the LLC. Counts on a per core basis. Unit: cpu_atom]
  mem_bound_stalls.ifetch
       [Counts the number of cycles the core is stalled due to an instruction cache or TLB miss which hit in the L2,LLC,DRAM or MMIO (Non-DRAM). Unit: cpu_atom]
  mem_bound_stalls.ifetch_dram_hit
       [Counts the number of cycles the core is stalled due to an instruction cache or TLB miss which hit in DRAM or MMIO (Non-DRAM). Unit: cpu_atom]
  mem_bound_stalls.ifetch_l2_hit
       [Counts the number of cycles the core is stalled due to an instruction cache or TLB miss which hit in the L2 cache. Unit: cpu_atom]
  mem_bound_stalls.ifetch_llc_hit
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[/sys/…/cpu/cpu0/cache/index3]
└─$ cd /home/amamitsu/Applications/lab1/build/         
                                                                                                                                                                                                                                                                                                                                                                                                                                                                
┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ sudo perf stat ./lab1_gemm 256 256 256             
GEMM performance info:
                M, K, N: 256, 256, 256
                Ops: 0.0335544
                Total compute time(s): 2.64601
                Cost(s): 0.0132301
                Benchmark(Gflops): 2.53623

 Performance counter stats for './lab1_gemm 256 256 256':

          2,783.61 msec task-clock                #    1.000 CPUs utilized          
                 5      context-switches          #    0.002 K/sec                  
                 0      cpu-migrations            #    0.000 K/sec                  
               315      page-faults               #    0.113 K/sec                  
    12,192,461,909      cycles                    #    4.380 GHz                    
    49,467,952,031      instructions              #    4.06  insn per cycle         
     3,539,861,350      branches                  # 1271.682 M/sec                  
        13,868,932      branch-misses             #    0.39% of all branches        

       2.783398170 seconds time elapsed

       2.779785000 seconds user
       0.003999000 seconds sys
                                                                                                                                                                                                                                            
┌──(amamitsu㉿amamitsu)-[~/Applications/lab1/build]
└─$ sudo perf stat -e L1-dcache-loads,L1-dcache-load-misses,dTLB-loads,dTLB-load-misses ./lab1_gemm 256 256 256
GEMM performance info:
                M, K, N: 256, 256, 256
                Ops: 0.0335544
                Total compute time(s): 2.5879
                Cost(s): 0.0129395
                Benchmark(Gflops): 2.59318

 Performance counter stats for './lab1_gemm 256 256 256':

     7,065,909,560      L1-dcache-loads                                             
       232,261,146      L1-dcache-load-misses     #    3.29% of all L1-dcache hits  
     7,065,909,560      dTLB-loads                                                  
             4,167      dTLB-load-misses          #    0.00% of all dTLB cache hits 

       2.721420500 seconds time elapsed

       2.717708000 seconds user
       0.004002000 seconds sys
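The annotations that perf stat prints after "#" are ratios of raw counters. The sketch below recomputes three of them from the two runs above, and the struct and function names are hypothetical. Clock rate is cycles per nanosecond of task-clock, IPC is instructions per cycle, and the L1D figure is misses as a percentage of loads.

```c
#include <assert.h>

struct perf_summary { double ghz, ipc, l1d_miss_pct; };

static struct perf_summary summarize(double task_clock_ms, double cycles,
                                     double instructions,
                                     double l1d_loads, double l1d_misses)
{
    struct perf_summary s;
    assert(task_clock_ms > 0 && cycles > 0 && l1d_loads > 0);
    s.ghz = cycles / (task_clock_ms * 1e6);          /* cycles per ns = GHz */
    s.ipc = instructions / cycles;                   /* "insn per cycle" */
    s.l1d_miss_pct = 100.0 * l1d_misses / l1d_loads; /* "% of all L1-dcache hits" */
    return s;
}
```

Feeding in 2783.61 ms, 12,192,461,909 cycles, 49,467,952,031 instructions, 7,065,909,560 loads and 232,261,146 misses gives about 4.38 GHz, 4.06 IPC and 3.29%, as perf printed.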

Lab2 Optimizing Matrix Multiplication with Caches, Loops, and Blocking

Identifying the Matrix Multiplication Bottleneck with perf

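Checking the program with perf shows that L1/L2 cache misses are the main bottleneck. The run below records about 929 million L1D load misses and 631 million L2 prefetch misses.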

┌──(amamitsu㉿amamitsu)-[~/Applications/lab2/build]
└─$ sudo perf stat -e l2_rqsts.code_rd_hit,l2_rqsts.references,l1d.replacement,l1d_pend_miss.pending,l2_rqsts.pf_hit,l2_rqsts.pf_miss,L1-dcache-loads,L1-dcache-load-misses ./dist/bins/lab2_gemm_baseline 256 1024 256
GEMM performance info:
                M, K, N: 256, 1024, 256
                Ops: 0.134218
                Total compute time(s): 8.5102
                Cost(s): 0.042551
                Benchmark(Gflops): 3.15428

Performance counter stats for ‘./dist/bins/lab2_gemm_baseline 256 1024 256’:

           209,586     l2_rqsts.code_rd_hit                                    (49.99%)
     1,850,479,513     l2_rqsts.references                                     (50.02%)
       928,894,971     l1d.replacement                                         (50.03%)
     3,827,521,915     l1d_pend_miss.pending                                   (50.03%)
     1,137,852,364     l2_rqsts.pf_hit                                         (50.01%)
       631,431,485     l2_rqsts.pf_miss                                        (49.98%)
    28,253,616,176     L1-dcache-loads                                         (49.97%)
       929,403,945     L1-dcache-load-misses # 3.29% of all L1-dcache accesses (49.97%)

      15.272692474 seconds time elapsed

      15.255466000 seconds user
       0.000000000 seconds sys
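From these counters one can also estimate what share of L2 prefetch requests missed L2. This sketch assumes that l2_rqsts.pf_hit plus l2_rqsts.pf_miss covers all prefetch requests that reached L2. Note that perf multiplexed these eight events: each was counted about 50% of the time, as the percentage column shows. The counts are therefore scaled estimates and the ratios are approximate. The helper name is hypothetical.

```c
#include <assert.h>

/* part / whole as a percentage */
static double percent(double part, double whole)
{
    assert(whole > 0);
    return 100.0 * part / whole;
}
```

Here, percent(631431485, 1137852364 + 631431485) is about 35.7%, meaning roughly a third of the prefetches still missed L2. percent(929403945, 28253616176) is about 3.29%, the L1D load-miss rate.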

Optimizing Matrix Multiplication Performance with Prefetching

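To address this, we can use prefetching, which brings the data we will need into the cache ahead of time. Cache capacity is limited, however, so prefetch instructions should not be issued too often, because they can evict data that is still in use. In the end we prefetch A[$m+1$][$k$] and C[$m$][$n+1$] once per m iteration. Since n is 0 at that point, the latter is C[$m$][1], i.e. the start of row m of C. This gives a slight improvement, but in practice the effect is small, and it only helps for certain matrix sizes.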
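The same idea can be expressed in C with __builtin_prefetch, a GCC/Clang builtin. Its arguments are rw = 0 for a read and locality = 3 to keep the line in all cache levels; on x86, GCC normally emits prefetcht0 for this. The sketch keeps DO_GEMM's k, m, n loop order and, once per m iteration, prefetches the next row of A and row m of C. The function name is hypothetical. The bounds checks are needed in C because forming an out-of-range pointer is undefined behavior, whereas the assembly can skip them because prefetch instructions never fault.

```c
#include <assert.h>
#include <stddef.h>

/* k-m-n GEMM (C += A * B, row-major) with software prefetch hints. */
static void gemm_kij_prefetch(float *C, const float *A, const float *B,
                              size_t M, size_t N, size_t K)
{
    assert(C && A && B);
    for (size_t k = 0; k < K; k++)
        for (size_t m = 0; m < M; m++) {
            float a_mk = A[m * K + k];
            if (m + 1 < M)                                   /* A[m+1][k] */
                __builtin_prefetch(&A[(m + 1) * K + k], 0, 3);
            if (N > 1)                                       /* C[m][1], as in the asm */
                __builtin_prefetch(&C[m * N + 1], 0, 3);
            for (size_t n = 0; n < N; n++)
                C[m * N + n] += a_mk * B[k * N + n];
        }
}
```

Prefetch hints never change the result, so this computes exactly the same C as the version without them. Whether they help depends on the sizes involved, as the measurements below show.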

.text;
.p2align 2;
.global gemm_kernel_opt_prefetch;
.type gemm_kernel_opt_prefetch, %function;

#define     MAT_C               %rdi
#define     MAT_A               %rsi
#define     MAT_B               %r14
#define     DIM_M               %rcx
#define     DIM_N               %r8
#define     DIM_K               %r9
#define     loop_m              %r10
#define     loop_k              %r11
#define     loop_n              %r12
#define     mat_elem_idx        %r13
#define     prefetch_elem_idx   %r15


.macro PUSHD
    push %rax
    push %rbx
    push %rcx
    push %rdx
    push %rsi
    push %rdi
    push %rbp
    push %r8
    push %r9
    push %r10
    push %r11
    push %r12
    push %r13
    push %r14
    push %r15
.endm

.macro POPD
    pop %r15
    pop %r14
    pop %r13
    pop %r12
    pop %r11
    pop %r10
    pop %r9
    pop %r8
    pop %rbp
    pop %rdi
    pop %rsi
    pop %rdx
    pop %rcx
    pop %rbx
    pop %rax
.endm

.macro GEMM_INIT
    mov %rdx, MAT_B

    xor loop_m, loop_m
    xor loop_k, loop_k
    xor loop_n, loop_n
.endm

.macro DO_GEMM
DO_LOOP_K:
    xor loop_m, loop_m

DO_LOOP_M:
    xor loop_n, loop_n

    mov loop_m, %rax
    mul DIM_K
    mov %rax, mat_elem_idx
    add loop_k, mat_elem_idx                    // mat_elem_idx = m*K + k
    flds (MAT_A, mat_elem_idx, 4)               // load A[m][k]

    // prefetch A[m+1][k]
    mov loop_m, %rax
    add $1, %rax
    mul DIM_K
    mov %rax, prefetch_elem_idx
    add loop_k, prefetch_elem_idx
    prefetcht0 (MAT_A, prefetch_elem_idx, 4)    // prefetch A[m+1][k]

    mov DIM_N, %rax
    mul loop_m
    add $1, %rax
    mov %rax, prefetch_elem_idx
    add loop_n, prefetch_elem_idx
    prefetcht0 (MAT_C, prefetch_elem_idx, 4)    // prefetch C[m][n+1] (loop_n is 0 here)

DO_LOOP_N:
    mov DIM_N, %rax
    mul loop_k
    mov %rax, mat_elem_idx
    add loop_n, mat_elem_idx                    // mat_elem_idx = k*N + n
    flds (MAT_B, mat_elem_idx, 4)               // load B[k][n]
    fmul %st(1), %st(0)                         // compute A[m][k] * B[k][n]

    mov DIM_N, %rax
    mul loop_m
    mov %rax, mat_elem_idx
    add loop_n, mat_elem_idx                    // mat_elem_idx = m*N + n
    flds (MAT_C, mat_elem_idx, 4)               // load C[m][n]

    faddp %st(1), %st(0)                        // compute C[m][n] + A[m][k] * B[k][n]
    fstps (MAT_C, mat_elem_idx, 4)

    add $1, loop_n
    cmp DIM_N, loop_n
    jl DO_LOOP_N

    fstp %st(0)
    add $1, loop_m
    cmp DIM_M, loop_m
    jl DO_LOOP_M

    add $1, loop_k
    cmp DIM_K, loop_k
    jl DO_LOOP_K
.endm


gemm_kernel_opt_prefetch:
    PUSHD
    GEMM_INIT
    DO_GEMM
    POPD
    ret

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/lab2/build]
└─$ ./dist/bins/lab2_gemm_kernel_opt_prefetch.unittest              
Running main() from /home/amamitsu/Applications/lab2/build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from gemm_kernel_opt_prefetch
[ RUN      ] gemm_kernel_opt_prefetch.test0
[       OK ] gemm_kernel_opt_prefetch.test0 (12 ms)
[ RUN      ] gemm_kernel_opt_prefetch.test1
[       OK ] gemm_kernel_opt_prefetch.test1 (3 ms)
[ RUN      ] gemm_kernel_opt_prefetch.test2
[       OK ] gemm_kernel_opt_prefetch.test2 (3 ms)
[ RUN      ] gemm_kernel_opt_prefetch.test3
[       OK ] gemm_kernel_opt_prefetch.test3 (0 ms)
[----------] 4 tests from gemm_kernel_opt_prefetch (20 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (20 ms total)
[  PASSED  ] 4 tests.

┌──(amamitsu㉿amamitsu)-[~/Applications/lab2/build]
└─$ ./dist/bins/lab2_gemm_opt_prefetch 2048 64 64 
--- Performance before prefetch optimization ---
GEMM performance info:
                M, K, N: 2048, 64, 64
                Ops: 0.0167772
                Total compute time(s): 1.09243
                Cost(s): 0.00546215
                Benchmark(Gflops): 3.07154
--- Performance for after prefetch optimization ---
GEMM performance info:
                M, K, N: 2048, 64, 64
                Ops: 0.0167772
                Total compute time(s): 1.03568
                Cost(s): 0.00517838
                Benchmark(Gflops): 3.23985
----------------------------
Performance difference(Gflops): 0.168311

For this size, performance improved by about 5.5%, from 3.07 to 3.24 Gflops.

Improving Matrix Multiplication Performance with Loop Optimization and Blocking

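We split the matrix product into $16\times16$ blocks. A $16\times16$ tile of floats occupies 1 KiB, so the tiles of A, B and C currently being worked on fit easily in the 48 KiB L1D cache. They stay there while they are reused, which raises the cache hit rate. For further optimization, some instructions were aligned manually with the .p2align directives in the loop bodies.
Because the x86 assembly logic is fairly complex, equivalent C code is shown below.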

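void gemm_kernel_opt_loop(float *C, const float *A, const float *B, int M, int N, int K) {
    // C is M x N, A is M x K, B is K x N, all row-major
    for (int m = 0; m < M; m += 16)                         // 16x16 blocks along M
        for (int n = 0; n < N; n += 16)                     // ... along N
            for (int k = 0; k < K; k += 16) {               // ... along K
                int minMt = m + 16 < M ? m + 16 : M;        // clamp the last, partial block
                int minNt = n + 16 < N ? n + 16 : N;
                int minKt = k + 16 < K ? k + 16 : K;
                for (int mt = m; mt < minMt; mt++)
                    for (int nt = n; nt < minNt; nt++) {
                        float sum = 0.0f;                   // partial dot product (fldz ... faddp in LOOP_6)
                        for (int kt = k; kt < minKt; kt++)
                            sum += A[mt * K + kt] * B[kt * N + nt];
                        C[mt * N + nt] += sum;
                    }
            }
}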
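To check that a blocked loop nest like the one above gives the same result as the plain triple loop, the sketch below compares the two on sizes that are not multiples of 16. The names gemm_naive and gemm_blocked and the tile-size parameter bs are illustrative and not part of the lab code. The blocked inner loop accumulates into a local sum, which is how the assembly's LOOP_6 works (fldz, then faddp).

```c
#include <assert.h>
#include <stddef.h>

static size_t min_sz(size_t a, size_t b) { return a < b ? a : b; }

/* Reference: C += A * B with C (M x N), A (M x K), B (K x N), row-major. */
static void gemm_naive(float *C, const float *A, const float *B,
                       size_t M, size_t N, size_t K)
{
    for (size_t m = 0; m < M; m++)
        for (size_t n = 0; n < N; n++)
            for (size_t k = 0; k < K; k++)
                C[m * N + n] += A[m * K + k] * B[k * N + n];
}

/* Same product computed tile by tile; edge tiles are clamped with min_sz. */
static void gemm_blocked(float *C, const float *A, const float *B,
                         size_t M, size_t N, size_t K, size_t bs)
{
    assert(bs > 0);
    for (size_t m0 = 0; m0 < M; m0 += bs)
        for (size_t n0 = 0; n0 < N; n0 += bs)
            for (size_t k0 = 0; k0 < K; k0 += bs)
                for (size_t m = m0; m < min_sz(m0 + bs, M); m++)
                    for (size_t n = n0; n < min_sz(n0 + bs, N); n++) {
                        float sum = 0.0f;   /* partial dot product for this k tile */
                        for (size_t k = k0; k < min_sz(k0 + bs, K); k++)
                            sum += A[m * K + k] * B[k * N + n];
                        C[m * N + n] += sum;
                    }
}
```

The blocked version sums in a different order from the naive one, so with general floating-point inputs the two results should be compared with a small tolerance. With small integer-valued inputs every partial sum is exact, and the results can be compared for equality.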
.text;
.p2align 2;
.global gemm_kernel_opt_loop;
.type gemm_kernel_opt_loop, %function;

#define     MAT_C               %rdi
#define     MAT_A               %rsi
#define     MAT_B               %r14
#define     DIM_M               %rcx
#define     DIM_N               %r8d
#define     DIM_K               %r9d
#define     loop_m              %r10d
#define     loop_k              %r11d
#define     loop_n              %r12


.macro PUSHD
    push %rax
    push %rbx
    push %rcx
    push %rdx
    push %rsi
    push %rdi
    push %rbp
    push %r8
    push %r9
    push %r10
    push %r11
    push %r12
    push %r13
    push %r14
    push %r15
.endm

.macro POPD
    pop %r15
    pop %r14
    pop %r13
    pop %r12
    pop %r11
    pop %r10
    pop %r9
    pop %r8
    pop %rbp
    pop %rdi
    pop %rsi
    pop %rdx
    pop %rcx
    pop %rbx
    pop %rax
.endm

.macro DO_GEMM
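    // 16x16 blocking in six nested loops:
    //   LOOP_1: m blocks (loop_m)        LOOP_2: n blocks (kept in loop_k, %r11d)
    //   LOOP_3: k blocks (-104(%rsp))    LOOP_4: rows mt (-72(%rsp))
    //   LOOP_5: columns nt (%rbp)        LOOP_6: dot product over kt on the x87 stack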
	mov	$0, loop_m
	cmp	%ecx, loop_m
	jge	OVER
	mov	%rdx, %r13
	mov	DIM_N, %eax
	mov	DIM_N, -56(%rsp)
	cdqe
	lea	0(,%rax,4), %r15
	jmp	LOOP_1_VARSET
LOOP_6_VARSET:
	mov	%r13, %rdx
	mov	-80(%rsp), %rax
	fldz
	.p2align 5
LOOP_6:
	flds	(%rax)
	fmuls	(%rdx)
	faddp	%st, %st(1)
	add	$4, %rax
	add	%r15, %rdx
	cmp	%rbx, %rax
	jne	LOOP_6
LOOP_5_OVERCHECK:
	mov	-96(%rsp), %rax
	fadds	(%rax,%rbp,4)
	fstps	(%rax,%rbp,4)
	add	$1, %rbp
	add	$4, %r13
	cmp	%ebp, -88(%rsp)
	jle	LOOP_4_OVERCHECK
LOOP_5_VARSET:
	mov	-104(%rsp), %edx
	cmp	%edx, -84(%rsp)
	jg	LOOP_6_VARSET
	fldz
	jmp	LOOP_5_OVERCHECK
	.p2align 6
LOOP_4_OVERCHECK:
	addq	$1, -72(%rsp)
	mov	-72(%rsp), %eax
	mov	-56(%rsp), %edx
	add	%edx, -68(%rsp)
	mov	-16(%rsp), %edx
	add	%edx, -64(%rsp)
	cmp	%eax, -52(%rsp)
	je	BLOCK_COMPLETE
LOOP_4_VARSET:
	mov	-88(%rsp), %ebx
	cmp	%ebx, -60(%rsp)
	jge	LOOP_4_OVERCHECK
	movsxd	-64(%rsp), %rdx
	mov	-104(%rsp), %rbx
	lea	(%rdx,%rbx), %rax
	mov	-48(%rsp), %rbp
	lea	0(%rbp,%rax,4), %rax
	mov	%rax, -80(%rsp)
	mov	-12(%rsp), %eax
	add	%rbx, %rax
	add	%rdx, %rax
	lea	0(%rbp,%rax,4), %rbx
	movsxd	-68(%rsp), %rax
	mov	-40(%rsp), %rdx
	lea	(%rdx,%rax,4), %rax
	mov	%rax, -96(%rsp)
	mov	-32(%rsp), %r13
	mov	-24(%rsp), %rbp
	jmp	LOOP_5_VARSET
BLOCK_COMPLETE:
	mov	-8(%rsp), %r13
LOOP_3_OVERCHECK:
	mov	DIM_K, %ebx
	addq	$16, -104(%rsp)
	mov	-104(%rsp), %rax
	cmp	%eax, DIM_K
	jle	LOOP_2_OVERCHECK
LOOP_3_VARSET:
	mov	-104(%rsp), %rbp
	mov	%ebp, -96(%rsp)
	mov	loop_m, -72(%rsp)
	lea	16(loop_m), %eax
	cmp	%ecx, %eax
	cmovg	%ecx, %eax
	mov	%eax, %edx
	mov	%eax, -52(%rsp)
	mov	loop_k, -60(%rsp)
	mov	-104(%rsp), %eax
	add	$16, %eax
	cmp	%ebx, %eax
	cmovg	%ebx, %eax
	mov	%eax, -84(%rsp)
	cmp	%edx, loop_m
	jge	LOOP_3_OVERCHECK
	mov	MAT_A, -48(%rsp)
	mov	MAT_C, -40(%rsp)
	mov	-56(%rsp), %eax
	mov	%eax, %edx
	imul	loop_m, %edx
	mov	%edx, -68(%rsp)
	mov	%ebx, -16(%rsp)
	imul	loop_m, %ebx
	mov	%ebx, -64(%rsp)
	movsxd	loop_k, %rbx
	mov	%rbx, -24(%rsp)
	mov	-104(%rsp), %edx
	imul	%edx, %eax
	cdqe
	add	%rbx, %rax
	lea	0(%r13,%rax,4), %rax
	mov	%rax, -32(%rsp)
	mov	-84(%rsp), %ebp
	mov	-96(%rsp), %eax
	sub	%eax, %ebp
	mov	%ebp, -12(%rsp)
	mov	%r13, -8(%rsp)
	jmp	LOOP_4_VARSET
LOOP_2_OVERCHECK:
	lea	16(loop_k), %eax
	mov	%eax, loop_k
	cmp	DIM_N, %eax
	jge	LOOP_1_OVERCHECK
LOOP_2_VARSET:
	mov	DIM_K, %ebx
	test	DIM_K, DIM_K
	jle	LOOP_2_OVERCHECK
	movq	$0, -104(%rsp)
	lea	16(loop_k), %eax
	cmp	DIM_N, %eax
	cmovg	DIM_N, %eax
	mov	%eax, -88(%rsp)
	jmp	LOOP_3_VARSET
LOOP_1_OVERCHECK:
	lea	16(loop_m), %eax
	mov	%eax, loop_m
	cmp	%ecx, %eax
	jge	OVER
LOOP_1_VARSET:
	mov	$0, loop_k
	test	DIM_N, DIM_N
	jg	LOOP_2_VARSET
	jmp	LOOP_1_OVERCHECK
OVER:
.endm

gemm_kernel_opt_loop:
    PUSHD
    DO_GEMM
    POPD
    ret

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/lab2/build]
└─$ ./dist/bins/lab2_gemm_kernel_opt_loop.unittest
Running main() from /home/amamitsu/Applications/lab2/build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from gemm_kernel_opt_loop
[ RUN      ] gemm_kernel_opt_loop.test0
[       OK ] gemm_kernel_opt_loop.test0 (1 ms)
[ RUN      ] gemm_kernel_opt_loop.test1
[       OK ] gemm_kernel_opt_loop.test1 (2 ms)
[ RUN      ] gemm_kernel_opt_loop.test2
[       OK ] gemm_kernel_opt_loop.test2 (2 ms)
[ RUN      ] gemm_kernel_opt_loop.test3
[       OK ] gemm_kernel_opt_loop.test3 (0 ms)
[----------] 4 tests from gemm_kernel_opt_loop (6 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (6 ms total)
[  PASSED  ] 4 tests.

┌──(amamitsu㉿amamitsu)-[~/Applications/lab2/build]
└─$ ./dist/bins/lab2_gemm_opt_loop 2048 512 64                                       
--- Performance before loop optimization ---
GEMM performance info:
                M, K, N: 2048, 512, 64
                Ops: 0.134218
                Total compute time(s): 8.63971
                Cost(s): 0.0431986
                Benchmark(Gflops): 3.107
--- Performance for after loop optimization ---
GEMM performance info:
                M, K, N: 2048, 512, 64
                Ops: 0.134218
                Total compute time(s): 4.44931
                Cost(s): 0.0222465
                Benchmark(Gflops): 6.0332
----------------------------
Performance difference(Gflops): 2.9262

┌──(amamitsu㉿amamitsu)-[~/Applications/lab2/build]
└─$ ./dist/bins/lab2_gemm_opt_loop 4444 444 444       
--- Performance before loop optimization ---
GEMM performance info:
                M, K, N: 4444, 444, 444
                Ops: 1.75214
                Total compute time(s): 109.722
                Cost(s): 0.548611
                Benchmark(Gflops): 3.19378
--- Performance for after loop optimization ---
GEMM performance info:
                M, K, N: 4444, 444, 444
                Ops: 1.75214
                Total compute time(s): 58.2671
                Cost(s): 0.291336
                Benchmark(Gflops): 6.01418
----------------------------
Performance difference(Gflops): 2.8204

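By carefully optimizing the assembly code, performance improved by 94.2% in the 2048×512×64 case (3.107 → 6.0332 Gflops). Incidentally, even without instruction alignment the improvement was 44.3%, which is still clearly effective compared with the original version.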
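The figures in these logs are consistent with Ops = 2·M·K·N / 10^9 (2·2048·512·64 / 10^9 ≈ 0.134218) and Benchmark(Gflops) = Ops / Cost(s) (0.134218 / 0.0431986 ≈ 3.107). The helpers below are hypothetical and not part of the lab framework. They reproduce the 94.2% figure; the 4444×444×444 run works out to about 88.3%.

```cpp
#include <cassert>
#include <cstdint>

// Work of one M x K x N GEMM in units of 10^9 floating-point operations
// (one multiply and one add per inner-product term).
double gemm_gflop(std::uint64_t M, std::uint64_t K, std::uint64_t N) {
    return 2.0 * static_cast<double>(M) * K * N / 1e9;
}

// Throughput in Gflops, given the time of one GEMM call in seconds.
double gflops(double gflop, double seconds) {
    assert(seconds > 0.0);
    return gflop / seconds;
}

// Relative improvement in percent (e.g. 94.2 means "+94.2%").
double improvement_percent(double before, double after) {
    return (after / before - 1.0) * 100.0;
}
```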

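Lab3 Optimizing matrix multiplication with instruction-level parallelism, vector instructions, and parallel processing

Optimizing matrix multiplication performance with the x87 FPU

Here we optimize the computation with x87 FPU instructions by unrolling the n loop. The work is divided into the following steps:

(1) Compute A[$m$][$k$] $\times$ B[$k$][$n+1$]

(2) Load C[$m$][$n$] into st(1) and C[$m$][$n+1$] into st(0)

(3) Compute C[$m$][$n+1$] + A[$m$][$k$] $\times$ B[$k$][$n+1$]

(4) Compute C[$m$][$n$] + A[$m$][$k$] $\times$ B[$k$][$n$]

(5) Store C[$m$][$n$]

(6) Update the counter of the n loop

The key point is to understand how each instruction changes the floating-point register stack.

In the code, each iteration of the n loop handles two multiplications, so once the steps above are done, the n loop counter is incremented by 2.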
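The steps above follow the loop order of the assembly: k outermost, then m, then n, with A[m][k] held fixed for the whole n loop. The plain C++ reference below (function name hypothetical) performs the same arithmetic in the same order and can be used to check the kernel's output. Like the unrolled assembly, it assumes N is even, because every iteration also touches column n+1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference for the unrolled x87 loop nest. C (M x N), A (M x K) and
// B (K x N) are row-major; the products are accumulated into C.
void gemm_kmn_unroll2(std::vector<float>& C, const std::vector<float>& A,
                      const std::vector<float>& B,
                      std::size_t M, std::size_t N, std::size_t K) {
    assert(N % 2 == 0);  // two columns per iteration, as in the assembly
    for (std::size_t k = 0; k < K; ++k) {
        for (std::size_t m = 0; m < M; ++m) {
            const float a = A[m * K + k];              // A[m][k], kept in st(0) by the asm
            for (std::size_t n = 0; n < N; n += 2) {   // step (6): n += 2
                const float p0 = a * B[k * N + n];     // A[m][k] * B[k][n]
                const float p1 = a * B[k * N + n + 1]; // step (1): A[m][k] * B[k][n+1]
                C[m * N + n + 1] += p1;                // step (3)
                C[m * N + n] += p0;                    // steps (4) and (5)
            }
        }
    }
}
```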

.text;
.p2align 2;
.global gemm_kernel_opt_loop_unrolling;
.type gemm_kernel_opt_loop_unrolling, %function;

#define     MAT_C               %rdi
#define     MAT_A               %rsi
#define     MAT_B               %r14
#define     DIM_M               %rcx
#define     DIM_N               %r8
#define     DIM_K               %r9
#define     loop_m              %r10
#define     loop_k              %r11
#define     loop_n              %r12
#define     mat_elem_idx        %r13


.macro PUSHD   // Save the current values of the general-purpose registers
    push %rax
    push %rbx
    push %rcx
    push %rdx
    push %rsi
    push %rdi
    push %rbp
    push %r8
    push %r9
    push %r10
    push %r11
    push %r12
    push %r13
    push %r14
    push %r15
.endm

.macro POPD    // Restore the saved general-purpose registers
    pop %r15
    pop %r14
    pop %r13
    pop %r12
    pop %r11
    pop %r10
    pop %r9
    pop %r8
    pop %rbp
    pop %rdi
    pop %rsi
    pop %rdx
    pop %rcx
    pop %rbx
    pop %rax
.endm

.macro GEMM_INIT
    mov %rdx, MAT_B

    xor loop_m, loop_m
    xor loop_k, loop_k
    xor loop_n, loop_n
.endm

.macro DO_GEMM
DO_LOOP_K:
    xor loop_m, loop_m

DO_LOOP_M:
    xor loop_n, loop_n

    mov loop_m, %rax
    mul DIM_K
    mov %rax, mat_elem_idx
    add loop_k, mat_elem_idx          // Compute m * K + k
    flds (MAT_A, mat_elem_idx, 4)     // Load A[m][k]; it stays on the FPU stack for the whole n loop

DO_LOOP_N:
    mov DIM_N, %rax
    mul loop_k
    mov %rax, mat_elem_idx
    add loop_n, mat_elem_idx
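    flds (MAT_B, mat_elem_idx, 4)     // Load B[k][n]
    fmul %st(1), %st(0)               // st(0) = A[m][k] * B[k][n]

    // Compute A[m][k] * B[k][n+1]
    add $1, mat_elem_idx
    flds (MAT_B, mat_elem_idx, 4)     // Load B[k][n+1]; A[m][k] is now st(2)
    fmul %st(2), %st(0)               // st(0) = A[m][k] * B[k][n+1]

    mov DIM_N, %rax
    mul loop_m
    mov %rax, mat_elem_idx
    add loop_n, mat_elem_idx          // Compute m * N + n
    // Load C[m][n] into st(1) and C[m][n+1] into st(0)
    flds (MAT_C, mat_elem_idx, 4)
    add $1, mat_elem_idx
    flds (MAT_C, mat_elem_idx, 4)

    // Accumulate the partial sums
    // C[m][n+1] + A[m][k] * B[k][n+1], C[m][n] + A[m][k] * B[k][n]
    faddp %st, %st(2)                 // st(2) += C[m][n+1], then pop
    faddp %st, %st(2)                 // st(2) += C[m][n], then pop

    fstps (MAT_C, mat_elem_idx, 4)    // Store C[m][n+1]

    // Store C[m][n]
    sub $1, mat_elem_idx
    fstps (MAT_C, mat_elem_idx, 4)

    // Update the N loop: two columns per iteration (this assumes N is even)
    add $2, loop_n

    cmp DIM_N, loop_n
    jl DO_LOOP_N

    fstp %st(0)                   // Pop only st(0), i.e. A[m][k]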
    add $1, loop_m
    cmp DIM_M, loop_m
    jl DO_LOOP_M

    add $1, loop_k
    cmp DIM_K, loop_k
    jl DO_LOOP_K
.endm

gemm_kernel_opt_loop_unrolling:
    PUSHD
    GEMM_INIT
    DO_GEMM
    POPD
    ret
FPU register stack after each instruction (st(0) is the top of the stack; mat_elem_idx is recomputed as m * N + n before the first load of C):

flds (MAT_B, mat_elem_idx, 4)
    st(0) = B[k][n], st(1) = A[m][k]
fmul %st(1), %st(0)
    st(0) = B[k][n]*A[m][k], st(1) = A[m][k]
add $1, mat_elem_idx
flds (MAT_B, mat_elem_idx, 4)
    st(0) = B[k][n+1], st(1) = B[k][n]*A[m][k], st(2) = A[m][k]
fmul %st(2), %st(0)
    st(0) = B[k][n+1]*A[m][k], st(1) = B[k][n]*A[m][k], st(2) = A[m][k]
flds (MAT_C, mat_elem_idx, 4)
    st(0) = C[m][n], st(1) = B[k][n+1]*A[m][k], st(2) = B[k][n]*A[m][k], st(3) = A[m][k]
add $1, mat_elem_idx
flds (MAT_C, mat_elem_idx, 4)
    st(0) = C[m][n+1], st(1) = C[m][n], st(2) = B[k][n+1]*A[m][k], st(3) = B[k][n]*A[m][k], st(4) = A[m][k]
faddp %st, %st(2)        (st(2) += st(0), then pop)
    st(0) = C[m][n], st(1) = B[k][n+1]*A[m][k]+C[m][n+1], st(2) = B[k][n]*A[m][k], st(3) = A[m][k]
faddp %st, %st(2)        (st(2) += st(0), then pop)
    st(0) = B[k][n+1]*A[m][k]+C[m][n+1], st(1) = B[k][n]*A[m][k]+C[m][n], st(2) = A[m][k]
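The stack transitions can also be checked with a small model of the x87 register stack, where st(i) is the i-th entry from the top. This C++ sketch uses hypothetical names and simplifies one detail: it computes in float, whereas the x87 keeps 80-bit values internally. It replays the sequence above for one (m, k, n) step.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal model of the x87 register stack: the back of the vector is st(0).
struct FpuStack {
    std::vector<float> r;
    float& st(std::size_t i) { return r[r.size() - 1 - i]; }
    void fld(float v) { r.push_back(v); }                          // flds: push
    void fmul_st0(std::size_t i) { st(0) *= st(i); }               // fmul %st(i), %st(0)
    void faddp_to(std::size_t i) { st(i) += st(0); r.pop_back(); } // faddp %st, %st(i)
    float fstp() { float v = st(0); r.pop_back(); return v; }      // fstps: store and pop
};

// Replays the sequence for one unrolled step; c0 = C[m][n], c1 = C[m][n+1].
void unrolled_step(float a, float b0, float b1, float& c0, float& c1) {
    FpuStack s;
    s.fld(a);          // A[m][k], loaded once per m iteration
    s.fld(b0);         // flds B[k][n]
    s.fmul_st0(1);     // st(0) = B[k][n] * A[m][k]
    s.fld(b1);         // flds B[k][n+1]
    s.fmul_st0(2);     // st(0) = B[k][n+1] * A[m][k]
    s.fld(c0);         // flds C[m][n]
    s.fld(c1);         // flds C[m][n+1]
    s.faddp_to(2);     // st(2) += C[m][n+1], pop
    s.faddp_to(2);     // st(2) += C[m][n], pop
    c1 = s.fstp();     // fstps C[m][n+1]
    c0 = s.fstp();     // fstps C[m][n]
    assert(s.r.size() == 1 && s.st(0) == a);  // only A[m][k] remains
}
```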

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_loop_unrolling.unittest              
Running main() from /home/amamitsu/Applications/Lab3/build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 3 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 3 tests from gemm_kernel_opt_loop_unrolling
[ RUN      ] gemm_kernel_opt_loop_unrolling.test0
[       OK ] gemm_kernel_opt_loop_unrolling.test0 (1 ms)
[ RUN      ] gemm_kernel_opt_loop_unrolling.test1
[       OK ] gemm_kernel_opt_loop_unrolling.test1 (0 ms)
[ RUN      ] gemm_kernel_opt_loop_unrolling.test2
[       OK ] gemm_kernel_opt_loop_unrolling.test2 (1 ms)
[----------] 3 tests from gemm_kernel_opt_loop_unrolling (3 ms total)

[----------] Global test environment tear-down
[==========] 3 tests from 1 test suite ran. (3 ms total)
[  PASSED  ] 3 tests.

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_loop_unrolling 2048 512 64
--- Performance before loop unrolling optimization ---
GEMM performance info:
                M, K, N: 2048, 512, 64
                Ops: 0.134218
                Total compute time(s): 8.9804
                Cost(s): 0.044902
                Benchmark(Gflops): 2.98913
--- Performance for after loop unrolling optimization ---
GEMM performance info:
                M, K, N: 2048, 512, 64
                Ops: 0.134218
                Total compute time(s): 7.44397
                Cost(s): 0.0372198
                Benchmark(Gflops): 3.60608
----------------------------
Performance difference(Gflops): 0.616955

Unrolling the loop once (two columns per iteration) improves performance by about 20% (2.98913 → 3.60608 Gflops).

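A high-performance (2m,32n,32k) matrix multiplication kernel with AVX instructions

In this task we implement the logic that broadcasts values and loads data into the vector registers. Following the example in the code, we first compute the address of the element to broadcast and then broadcast it with $\textbf{vbroadcastss}$, which copies one float into all 8 lanes of a 256-bit ymm register.

To load data into registers, use $\textbf{vmovups}$, an unaligned load (or store) of 8 consecutive floats.

For the computation, use $\textbf{vfmadd231ps}$ to perform C[$m$][$n$] += A[$m$][$k$] $\times$ B[$k$][$n$] on 8 lanes at once. In AT&T syntax, vfmadd231ps src2, src1, dst computes dst += src1 $\times$ src2.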
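Each ymm register holds 8 floats, so a 2×32 block of C occupies 8 registers (mat_c0_* and mat_c1_*), and each k step needs 2 broadcasts, 4 loads of B and 8 vfmadd231ps. The C++ below is a scalar emulation that models each register as std::array<float, 8> to show what the macros compute. It is not the lab's kernel, and the names are hypothetical. With <immintrin.h>, the corresponding intrinsics would be _mm256_broadcast_ss, _mm256_loadu_ps, _mm256_fmadd_ps and _mm256_storeu_ps.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Vec8 = std::array<float, 8>;  // one 256-bit ymm register as 8 float lanes

// vbroadcastss: copy one float into all 8 lanes.
Vec8 broadcast(float x) { Vec8 v; v.fill(x); return v; }

// vmovups (load): read 8 consecutive floats, no alignment requirement.
Vec8 load8(const float* p) {
    Vec8 v;
    for (std::size_t i = 0; i < 8; ++i) v[i] = p[i];
    return v;
}

// vmovups (store): write 8 consecutive floats.
void store8(float* p, const Vec8& v) {
    for (std::size_t i = 0; i < 8; ++i) p[i] = v[i];
}

// vfmadd231ps b, a, c (AT&T operand order): c += a * b, lane by lane.
void fmadd231(Vec8& c, const Vec8& a, const Vec8& b) {
    for (std::size_t i = 0; i < 8; ++i) c[i] += a[i] * b[i];
}

// C[m:m+2][n:n+32] += A[m:m+2][0:K] * B[0:K][n:n+32] for row-major A (M x K),
// B (K x N), C (M x N): LOAD_MAT_C, then per k LOAD_MAT_A / LOAD_MAT_B /
// DO_COMPUTE, and finally STORE_MAT_C.
void micro_kernel_2x32(float* C, const float* A, const float* B,
                       std::size_t m, std::size_t n, std::size_t N, std::size_t K) {
    assert(n + 32 <= N);
    Vec8 c[2][4];                                        // mat_c0_* and mat_c1_*
    for (std::size_t r = 0; r < 2; ++r)
        for (std::size_t j = 0; j < 4; ++j)
            c[r][j] = load8(&C[(m + r) * N + n + 8 * j]);
    for (std::size_t k = 0; k < K; ++k) {
        const Vec8 a0 = broadcast(A[m * K + k]);         // mat_a0_0_8
        const Vec8 a1 = broadcast(A[(m + 1) * K + k]);   // mat_a1_0_8
        for (std::size_t j = 0; j < 4; ++j) {
            const Vec8 b = load8(&B[k * N + n + 8 * j]); // mat_b0_*
            fmadd231(c[0][j], a0, b);
            fmadd231(c[1][j], a1, b);
        }
    }
    for (std::size_t r = 0; r < 2; ++r)
        for (std::size_t j = 0; j < 4; ++j)
            store8(&C[(m + r) * N + n + 8 * j], c[r][j]);
}
```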

.text;
.p2align 2;
.global gemm_kernel_opt_avx;
.type gemm_kernel_opt_avx, %function;


#define     AVX_REG_BYTE_WIDTH  32

#define     MAT_C               %rdi
#define     MAT_A               %rsi
#define     MAT_B               %r13
#define     DIM_M               %rcx
#define     DIM_N               %r8
#define     DIM_K               %r9
#define     loop_m              %r10
#define     loop_k              %r11
#define     loop_n              %r12
#define     mat_elem_idx        %r14
#define     temp_reg            %r15

// AVX registers used during the computation
#define     mat_c0_0_8           %ymm0
#define     mat_c0_8_16          %ymm1
#define     mat_c0_16_24         %ymm2
#define     mat_c0_24_32         %ymm3
#define     mat_c1_0_8           %ymm4
#define     mat_c1_8_16          %ymm5
#define     mat_c1_16_24         %ymm6
#define     mat_c1_24_32         %ymm7
#define     mat_a0_0_8           %ymm8
#define     mat_a1_0_8           %ymm9
#define     mat_b0_0_8           %ymm10
#define     mat_b0_8_16          %ymm11
#define     mat_b0_16_24         %ymm12
#define     mat_b0_24_32         %ymm13

.macro PUSHD
    push %rax
    push %rbx
    push %rcx
    push %rdx
    push %rsi
    push %rdi
    push %rbp
    push %r8
    push %r9
    push %r10
    push %r11
    push %r12
    push %r13
    push %r14
    push %r15
.endm

.macro POPD
    pop %r15
    pop %r14
    pop %r13
    pop %r12
    pop %r11
    pop %r10
    pop %r9
    pop %r8
    pop %rbp
    pop %rdi
    pop %rsi
    pop %rdx
    pop %rcx
    pop %rbx
    pop %rax
.endm

.macro GEMM_INIT
    mov %rdx, MAT_B
.endm

.macro LOAD_MAT_A     // Load two elements from the same column of A (A[m][k], A[m+1][k])
    // Compute the element index of A[m][k] (m * K + k)
    mov loop_m, %rax
    mul DIM_K
    mov %rax, temp_reg
    add loop_k, temp_reg

    // Compute the byte offset of A[m][k]
    mov temp_reg, mat_elem_idx
    shl $2, mat_elem_idx        // *=4

    vbroadcastss (MAT_A, mat_elem_idx), mat_a0_0_8    // Broadcast A[m][k] to all 8 lanes of the AVX register

    // TODO: add the logic to load A[m+1][k] and broadcast it into mat_a1_0_8

.endm

.macro LOAD_MAT_B    // Load 32 elements from one row of B (B[k][n:n+32])

    // TODO: add the logic to load B[k][n:n+32] into mat_b0_0_8, mat_b0_8_16, mat_b0_16_24, mat_b0_24_32

.endm

.macro LOAD_MAT_C
    mov loop_m, %rax
    mul DIM_N
    mov %rax, temp_reg
    add loop_n, temp_reg

    // Load the first row of C (C[m][n:n+32])
    mov temp_reg, mat_elem_idx
    shl $2, mat_elem_idx        // *=4

    // TODO: add the logic to load C[m][n:n+32] into mat_c0_0_8, mat_c0_8_16, mat_c0_16_24, mat_c0_24_32

    // Load the second row of C (C[m+1][n:n+32])
    mov temp_reg, mat_elem_idx
    add DIM_N, mat_elem_idx
    shl $2, mat_elem_idx        // *=4

    // TODO: add the logic to load C[m+1][n:n+32] into mat_c1_0_8, mat_c1_8_16, mat_c1_16_24, mat_c1_24_32

.endm

.macro STORE_MAT_C
    mov loop_m, %rax
    mul DIM_N
    mov %rax, temp_reg
    add loop_n, temp_reg

    // Store the first row of C, C[m][n:n+32]
    mov temp_reg, mat_elem_idx
    shl $2, mat_elem_idx        // *=4

    // TODO: add the logic to store mat_c0_0_8, mat_c0_8_16, mat_c0_16_24, mat_c0_24_32 into C[m][n:n+32]

    // Store the second row of C, C[m+1][n:n+32]
    // TODO: add the logic to store mat_c1_0_8, mat_c1_8_16, mat_c1_16_24, mat_c1_24_32 into C[m+1][n:n+32]

.endm

.macro DO_COMPUTE      // Compute C[m:m+2][n:n+32] += A[m:m+2][k] * B[k][n:n+32]

    // TODO: add the logic for C[m:m+2][n:n+32] += A[m:m+2][k] * B[k][n:n+32]

.endm


.macro DO_GEMM
    xor loop_n, loop_n
DO_LOOP_N:

    xor loop_m, loop_m
DO_LOOP_M:
    // Load the block of matrix C
    LOAD_MAT_C

    xor loop_k, loop_k
DO_LOOP_K:
    // Load the current slices of matrices A and B
    LOAD_MAT_A
    LOAD_MAT_B

    DO_COMPUTE

    add $1, loop_k              // kr=1
    cmp DIM_K, loop_k
    jl DO_LOOP_K

    // Store the results
    STORE_MAT_C

    add $2, loop_m              // mr=2
    cmp DIM_M, loop_m
    jl DO_LOOP_M

    add $32, loop_n             // nr=32
    cmp DIM_N, loop_n
    jl DO_LOOP_N

.endm

gemm_kernel_opt_avx:
    PUSHD
    GEMM_INIT
    DO_GEMM
    POPD
    ret

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_avx.unittest              
Running main() from /home/amamitsu/Applications/Lab3/build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 3 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 3 tests from gemm_kernel_opt_avx
[ RUN      ] gemm_kernel_opt_avx.test0
[       OK ] gemm_kernel_opt_avx.test0 (0 ms)
[ RUN      ] gemm_kernel_opt_avx.test1
[       OK ] gemm_kernel_opt_avx.test1 (0 ms)
[ RUN      ] gemm_kernel_opt_avx.test2
[       OK ] gemm_kernel_opt_avx.test2 (1 ms)
[----------] 3 tests from gemm_kernel_opt_avx (1 ms total)

[----------] Global test environment tear-down
[==========] 3 tests from 1 test suite ran. (2 ms total)
[  PASSED  ] 3 tests.

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_avx 2048 512 64  
--- Performance before avx optimization ---
GEMM performance info:
                M, K, N: 2048, 512, 64
                Ops: 0.134218
                Total compute time(s): 8.98306
                Cost(s): 0.0449153
                Benchmark(Gflops): 2.98824
--- Performance for after avx optimization ---
GEMM performance info:
                M, K, N: 2048, 512, 64
                Ops: 0.134218
                Total compute time(s): 0.319307
                Cost(s): 0.00159654
                Benchmark(Gflops): 84.0681
----------------------------
Performance difference(Gflops): 81.0799

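Performance improved by 2713% (2.98824 → 84.0681 Gflops, about 28×). This is because a single instruction processes a lot of data: each ymm register holds 8 floats, so one vfmadd231ps performs 8 multiply-adds. In addition, the 2×32 block of C stays in registers for the entire k loop; it is loaded once by LOAD_MAT_C and written back once by STORE_MAT_C.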

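Matrix multiplication of arbitrary shapes with OpenMP and AVX instructions

Here we try to split the computation into blocks that are as close to square as possible. For example, with 12 threads, matrix C is divided into $3\times 4$ parts, and the relevant data is prepared with padding where necessary. Note that get_parallel_thread_num in the code below fixes m_thread = 2 and sets n_thread = max_threads / m_thread, so in this code 12 threads are actually arranged as $2\times 6$.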
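One way to obtain the near-square split described above is to take the largest divisor of the thread count that does not exceed its square root. The helper below is hypothetical and is not used by the code or the measurements in this section.

```cpp
#include <cassert>

// Hypothetical helper: factor `threads` as m_thread * n_thread with the two
// factors as close as possible (m_thread = largest divisor <= sqrt(threads)).
void near_square_split(int threads, int& m_thread, int& n_thread) {
    assert(threads > 0);
    m_thread = 1;
    for (int d = 1; d * d <= threads; ++d) {
        if (threads % d == 0) {
            m_thread = d;
        }
    }
    n_thread = threads / m_thread;
}
```

Because the two factors always multiply to exactly `threads`, no thread is left idle. For a prime count such as 7, however, the split degenerates to 1 × 7.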

#include <omp.h>
#include "openmp_gemm.h"
#include "gemm_kernel_opt.h"
#include <cstring>

inline int get_parallel_thread_num(uint64_t M, uint64_t K, uint64_t N, int kernel_mr, int kernel_nr, int max_threads, int& m_thread, int& n_thread) {
    m_thread = 2;
    if (m_thread > max_threads) {
        m_thread = max_threads;
    }
    n_thread = max_threads / m_thread;
    return m_thread * n_thread;
}

void openmp_gemm_baseline(int thread_num, float *C, float *A, float *B, uint64_t M, uint64_t N, uint64_t K){
    // 定数
    const int KERNEL_MR = 2, KERNEL_NR = 32;    // TODO: must match the mr/nr of the AVX kernel
    int m_thread = 1, n_thread = 1;
    
    int real_thread_num = get_parallel_thread_num(M, K, N, KERNEL_MR, KERNEL_NR, thread_num, m_thread, n_thread);

#pragma omp parallel num_threads(real_thread_num) \
            default(none) \
            shared(C) \
            firstprivate(A, B, M, N, K, KERNEL_MR, KERNEL_NR, \
                m_thread, n_thread)
    {
        const int kernel_size = KERNEL_MR * KERNEL_NR;
        int thread_id = omp_get_thread_num();  // Threads are assigned in row-major order: in each row the thread IDs run 0, 1, 2, 3, ...
        /* Compute this thread's indices in the M and N dimensions */
        int thread_id_m = thread_id / n_thread;  // index in the M dimension
        int thread_id_n = thread_id % n_thread;  // index in the N dimension
        /* Compute where this thread's computation starts in each dimension */

        // Number of rows assigned to each thread in the M dimension
        int dim_m_per_thread = (M + m_thread - 1) / m_thread; // ceil(M / m_thread); the last thread may get fewer rows
        int m_padding = dim_m_per_thread % KERNEL_MR;
        if (m_padding != 0) {
            m_padding = KERNEL_MR - m_padding;
        }
        int dim_n_per_thread = (N + n_thread - 1) / n_thread; // ceil(N / n_thread); the last thread may get fewer columns
        int n_padding = dim_n_per_thread % KERNEL_NR;
        if (n_padding != 0) {
            n_padding = KERNEL_NR - n_padding;
        }

        // The split starts from the M dimension (and never splits K), so B can be shared
        // and C does not need to be accumulated across threads.
        // The start positions of A and C therefore have to be recomputed here.
        int thread_m_start = thread_id_m * dim_m_per_thread;
        int thread_m_end = thread_m_start + dim_m_per_thread;
        if (thread_m_end > M) {
            thread_m_end = M;
        }

        // Compute the start and end indices in the N dimension
        int thread_n_start = thread_id_n * dim_n_per_thread;
        int thread_n_end = thread_n_start + dim_n_per_thread;
        if (thread_n_end > N) {
            thread_n_end = N;
        }

        // Allocate zero-filled padded buffers for the three matrices
        auto A_padding = new float[(dim_m_per_thread + m_padding) * K];
        memset((void *) A_padding, 0, (dim_m_per_thread + m_padding) * K * sizeof(float));
        auto B_padding = new float[(dim_n_per_thread + n_padding) * K];
        memset((void *) B_padding, 0, (dim_n_per_thread + n_padding) * K * sizeof(float));
        auto C_padding = new float[(dim_m_per_thread + m_padding) * (dim_n_per_thread + n_padding)];
        memset((void *) C_padding, 0, (dim_m_per_thread + m_padding) * (dim_n_per_thread + n_padding) * sizeof(float));

        // Copy this thread's slices of A, B and C into the padded buffers
        for (int m = thread_m_start; m < thread_m_end; m++) {
            memcpy(A_padding + (m - thread_m_start) * K, A + m * K, K * sizeof(float));
        }

        for (int k = 0; k < K; k++) {
            memcpy(B_padding + k * (dim_n_per_thread + n_padding),
                   B + thread_n_start + k * N,
                   (thread_n_end - thread_n_start) * sizeof(float));
        }

        for (int m = thread_m_start; m < thread_m_end; m++) {
            memcpy(C_padding + (m - thread_m_start) * (dim_n_per_thread + n_padding),
                   C + m * N + thread_n_start,
                   (thread_n_end - thread_n_start) * sizeof(float));
        }

        // Run the AVX kernel on the padded buffers
        gemm_kernel_opt_avx(C_padding, A_padding, B_padding, (dim_m_per_thread + m_padding),
                                 (dim_n_per_thread + n_padding), K);

        // Write the valid part of the result back to C
        for (int m = thread_m_start; m < thread_m_end; m++) {
            memcpy(C + m * N + thread_n_start,
                   C_padding + (m - thread_m_start) * (dim_n_per_thread + n_padding),
                   (thread_n_end - thread_n_start) * sizeof(float));
        }

        delete[] A_padding;
        delete[] B_padding;
        delete[] C_padding;
    }
}
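The padding arithmetic in openmp_gemm_baseline can be checked by hand. With 7 threads, get_parallel_thread_num gives m_thread = 2 and n_thread = 3, so only 6 threads run. For M = N = 448, each thread gets ceil(448/2) = 224 rows, which is already a multiple of KERNEL_MR = 2, and ceil(448/3) = 150 columns, which are padded to 160, a multiple of KERNEL_NR = 32. The hypothetical helper below reproduces that computation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring openmp_gemm_baseline: the per-thread extent
// ceil(dim / parts), rounded up to a multiple of the kernel block size.
std::uint64_t padded_extent(std::uint64_t dim, int parts, int kernel_block) {
    assert(parts > 0 && kernel_block > 0);
    const std::uint64_t per_thread = (dim + parts - 1) / parts;  // dim_*_per_thread
    std::uint64_t padding = per_thread % kernel_block;            // *_padding
    if (padding != 0) {
        padding = kernel_block - padding;
    }
    return per_thread + padding;
}
```

For the 12-thread 256³ run (2 × 6 threads), each thread's 43 columns are padded to 64, so roughly a third of each thread's N extent is padding.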

Output

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_openmp.unittest 
Running main() from /home/amamitsu/Applications/Lab3/build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from openmp_gemm_opt
[ RUN      ] openmp_gemm_opt.test0
[       OK ] openmp_gemm_opt.test0 (0 ms)
[ RUN      ] openmp_gemm_opt.test1
[       OK ] openmp_gemm_opt.test1 (40 ms)
[ RUN      ] openmp_gemm_opt.test2
[       OK ] openmp_gemm_opt.test2 (282 ms)
[ RUN      ] openmp_gemm_opt.test3
[       OK ] openmp_gemm_opt.test3 (2443 ms)
[----------] 4 tests from openmp_gemm_opt (2766 ms total)

[----------] Global test environment tear-down
[==========] 4 tests from 1 test suite ran. (2767 ms total)
[  PASSED  ] 4 tests.

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_openmp 12 256 256 256
--- Performance before openmp strategy optimization ---
GEMM performance info:
                M, K, N: 256, 256, 256
                Ops: 0.0335544
                Total compute time(s): 0.041128
                Cost(s): 0.00020564
                Benchmark(Gflops): 163.171
--- Performance for after openmp strategy optimization ---
GEMM performance info:
                M, K, N: 256, 256, 256
                Ops: 0.0335544
                Total compute time(s): 0.019356
                Cost(s): 9.678e-05
                Benchmark(Gflops): 346.708
----------------------------
Performance difference(Gflops): 183.538

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab3/build]
└─$ ./dist/bins/lab3_gemm_opt_openmp 7 448 448 448
--- Performance before openmp strategy optimization ---
GEMM performance info:
                M, K, N: 448, 448, 448
                Ops: 0.179831
                Total compute time(s): 0.142757
                Cost(s): 0.000713785
                Benchmark(Gflops): 251.94
--- Performance for after openmp strategy optimization ---
GEMM performance info:
                M, K, N: 448, 448, 448
                Ops: 0.179831
                Total compute time(s): 0.080035
                Cost(s): 0.000400175
                Benchmark(Gflops): 449.38
----------------------------
Performance difference(Gflops): 197.441

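Performance roughly doubled: +112% with 12 threads at 256³ (163.171 → 346.708 Gflops) and +78% with 7 threads at 448³ (251.94 → 449.38 Gflops). This is attributed to better locality in the optimized version, which divides the threads into two-dimensional blocks of computation, whereas the original OpenMP version arranged the threads linearly.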
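One rough way to see the locality argument is to count how many elements of A and B a single thread reads. Assume the linear layout splits C only along M into p strips; this is an assumption, since the baseline code is not shown here. Under that assumption each thread reads (M/p)·K elements of A plus all K·N elements of B, while with an m × n grid it reads (M/m)·K + K·(N/n). For 256³ and 12 threads this gives about 71.0k floats for 12 × 1, 43.7k for the 2 × 6 layout used in the code, and 38.2k for 3 × 4. The helper name is hypothetical.

```cpp
#include <cassert>

// Elements of A and B one thread reads when C (M x N) is split into a
// pm x pn grid: (M / pm) full rows of A plus (N / pn) full columns of B.
double per_thread_reads(double M, double N, double K, int pm, int pn) {
    assert(pm > 0 && pn > 0);
    return (M / pm) * K + K * (N / pn);
}
```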

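Lab4 CUDA: Matrix multiplication on the GPU

In this task the goal is to complete the key computation logic in the code. It is actually simple.

The function MatrixMulKernel computes the product of two matrices \(M\) and \(N\) and stores the result in matrix \(P\).

float* d_M: pointer to the input matrix \(M\)
float* d_N: pointer to the input matrix \(N\)
float* d_P: pointer to the output matrix \(P\)
int width: width of the matrices (square matrices only)

Each thread of the CUDA kernel computes one element of the result matrix \(P\). The row and column indices of each thread are computed as follows:
\[\text{row} = \text{blockIdx.y} \times \text{blockDim.y} + \text{threadIdx.y}\]
\[\text{col} = \text{blockIdx.x} \times \text{blockDim.x} + \text{threadIdx.x}\]

The following condition guarantees that a thread only operates on valid elements of the matrix:
\[\text{if } (\text{row} < \text{width}) \text{ and } (\text{col} < \text{width})\]

The function computes each element of the result matrix \(P\) with the following formula: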
\[P[\text{row}, \text{col}] = \sum_{k=0}^{\text{width}-1} M[\text{row}, k] \times N[k, \text{col}]\]
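To see how the index formulas cover the matrix, including the threads of the last, partially filled blocks, the following C++ sketch emulates the launch grid on the CPU. The function name is hypothetical. Each simulated (blockIdx, threadIdx) pair computes row and col exactly as above and does nothing when it falls outside the matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU emulation of MatrixMulKernel's launch: ceil(width/tile)^2 blocks,
// each with tile x tile threads; P = M * N for square row-major matrices.
void emulate_naive_kernel(const std::vector<float>& M, const std::vector<float>& N,
                          std::vector<float>& P, int width, int tile) {
    assert(tile > 0 && P.size() == static_cast<std::size_t>(width) * width);
    const int grid = (width + tile - 1) / tile;       // blocks per axis
    for (int by = 0; by < grid; ++by)                 // blockIdx.y
        for (int bx = 0; bx < grid; ++bx)             // blockIdx.x
            for (int ty = 0; ty < tile; ++ty)         // threadIdx.y
                for (int tx = 0; tx < tile; ++tx) {   // threadIdx.x
                    const int row = by * tile + ty;   // blockIdx.y * blockDim.y + threadIdx.y
                    const int col = bx * tile + tx;   // blockIdx.x * blockDim.x + threadIdx.x
                    if (row < width && col < width) { // skip out-of-range threads
                        float p = 0.0f;
                        for (int k = 0; k < width; ++k)
                            p += M[row * width + k] * N[k * width + col];
                        P[row * width + col] = p;
                    }
                }
}
```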

In code, this is implemented as follows:

// #define USE_CUBLAS

#include <iostream>
#include <cstdio>
#include <cuda_runtime.h>
#ifdef USE_CUBLAS
#include <cublas_v2.h>
#endif
#include <device_launch_parameters.h>
#include <cmath>
using namespace std;

const int TILE_WIDTH = 16;	// Thread block (tile) size

/////////
// Matrix multiplication with shared memory (CUDA Kernel) on the device: C = A * B
/////////
const int BLOCK_SIZE = TILE_WIDTH;
__global__ void MatrixMulSharedMemKernel(float *A,
    float *B, float *C, int wA,
    int wB) {




}


//! For square matrices only
__global__ void MatrixMulKernel(float* d_M, float* d_N, float* d_P, int width)
{
  // Calculate the row index of the P element and M
  int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Calculate the column index of the P element and N
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  // Ensure the thread is within bounds
  if ( (row < width) && (col < width) ) {
    float pValue = 0.0f;

    // Each thread computes one element of the matrix:
    // the dot product of row `row` of M and column `col` of N
    for (int k = 0; k < width; k++)
      pValue += d_M[row * width + k] * d_N[k * width + col];

    // Store the computed value into the output matrix
    d_P[row * width + col] = pValue;
  }
}

////////////////////////////////////////////////////////////////////////////////
//! Compute reference data set matrix multiply on CPU
//! C = A * B
//! @param C          reference data, computed but preallocated
//! @param A          matrix A as provided to device
//! @param B          matrix B as provided to device
//! @param hA         height of matrix A
//! @param wA         width of matrix A
//! @param wB         width of matrix B
////////////////////////////////////////////////////////////////////////////////
void
matrixMulCPU(float *C, const float *A, const float *B, unsigned int hA, unsigned int wA, unsigned int wB)
{
    for (unsigned int i = 0; i < hA; ++i)
        for (unsigned int j = 0; j < wB; ++j)
        {
            double sum = 0;

            for (unsigned int k = 0; k < wA; ++k)
            {
                double a = A[i * wA + k];
                double b = B[k * wB + j];
                sum += a * b;
            }

            C[i * wB + j] = (float)sum;
        }
}

void printDiff(float *data1, float *data2, int width, int height, int iListLength, float fListTol)
{
    printf("Listing first %d Differences > %.6f...\n", iListLength, fListTol);
    int i,j,k;
    int error_count=0;

    for (j = 0; j < height; j++)
    {
        for (i = 0; i < width; i++)
        {
            k = j * width + i;
            float fDiff = fabs(data1[k] - data2[k]);

            if (fDiff > fListTol)
            {
                if (error_count < iListLength)
                {
                    printf("    Loc(%d,%d)\tCPU=%.5f\tGPU=%.5f\tDiff=%.6f\n", i, j, data1[k], data2[k], fDiff);
                }

                error_count++;
            }
        }
    }

    printf(" \n  Total Errors = %d\n", error_count);
}

void getArg(int argc, char* argv[], int &size, int &check)
{
  if (argc != 3)
  {
    cerr << "Usage: " << argv[0] << " <check_enable> <size>\n";
    cerr << "\tcheck_enable: 1 to enable result checking\n";
    cerr << "\tsize: size of the matrix\n";
    exit(1);
  }

  int val1, val2;
  try
  {
    val1 = stoi(argv[1]);
    val2 = stoi(argv[2]);
  }
  catch (const invalid_argument& e)
  {
    cerr << "ERROR: parameters should be integer\n";
    exit(1);
  }

  check = val1;
  size = val2;
}



int main(int argc, char* argv[])
{
  int size, check;
  getArg(argc, argv, size, check);

  int m = size, n = size, k = size;
  
  // Declare the host (h_*) and device (d_*) arrays
  float *h_M, *h_N, *d_M, *d_N;
  float *h_P, *d_P;
  
  size_t sizeM = m * k * sizeof(float);
  size_t sizeN = k * n * sizeof(float);
  size_t sizeP = m * n * sizeof(float);

  // Allocate host memory
  h_M = (float*) malloc(sizeM);
  h_N = (float*) malloc(sizeN);
  h_P = (float*) malloc(sizeP);
  float *reference = (float *)malloc(sizeP);

  // Allocate device memory
  cudaMalloc(&d_M, sizeM);
  cudaMalloc(&d_N, sizeN);
  cudaMalloc(&d_P, sizeP);

  // Init data 
  for(int i = 0; i < m * k; ++i)   // h_M holds m x k elements
  {
    if(i % 2 == 0)
      h_M[i] = 1.0;
    else
      h_M[i] = 0.5;
  }

  for(int i = 0;i < n * k; ++i)
  {
    if(i % 2 == 0)
      h_N[i] = 0.5;
    else
      h_N[i] = 1.0;
  }

  // Copy data from CPU to GPU
  cudaMemcpy(d_M, h_M, sizeM, cudaMemcpyHostToDevice);
  cudaMemcpy(d_N, h_N, sizeN, cudaMemcpyHostToDevice);

  // Timing records 
  cudaEvent_t start,stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start,0);

  // Launch kernel: define the grid and block dimensions
  // (x covers the n columns of P, y covers the m rows)
  dim3 grid((int)ceil(n*1.0 / TILE_WIDTH), (int)ceil(m*1.0/ TILE_WIDTH));
  dim3 block(TILE_WIDTH, TILE_WIDTH);
  
  int nIter = 5;
#ifdef USE_CUBLAS
  cublasHandle_t handle;
  cublasCreate(&handle);
#endif
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  for (int j = 0; j < nIter; j++) {
    //matrixMulCPU(reference, h_M, h_N, m, k, n);
    MatrixMulKernel<<<grid, block>>>(d_M, d_N, d_P, m);
    //MatrixMulSharedMemKernel<<<grid, block>>>(d_M, d_N, d_P, m, n);
    //cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, d_N, n, d_M, k, &beta, d_P, n);
  }

  cudaEventRecord(stop, 0);
  cudaEventSynchronize(stop);
  float msecPerMatrixMul;
  cudaEventElapsedTime(&msecPerMatrixMul, start, stop);
  msecPerMatrixMul /= nIter;
  printf("Kernel Elpased Time: %.3f ms\n", msecPerMatrixMul);

  // Compute and print the performance
  double flopsPerMatrixMul = 2.0 * (double)m * (double)n * (double)k;
  double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
  printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
		  gigaFlops,
		  msecPerMatrixMul,
		  flopsPerMatrixMul);

  // Copy data from GPU to CPU 
  cudaMemcpy(h_P, d_P, sizeP, cudaMemcpyDeviceToHost);

  // compute reference solution
  if (check == 1)
  {
    printf("Computing result using host CPU...");
    matrixMulCPU(reference, h_M, h_N, m, k, n);
    printf("done.\n");
    printDiff(reference, h_P, n, m, 100, 1.0e-5f);
  }

  free(h_P);
  free(h_M);
  free(h_N);
  cudaFree(d_P);
  cudaFree(d_M);
  cudaFree(d_N);
#ifdef USE_CUBLAS
  cublasDestroy(handle);
#endif

  return 0;
}

The output (block size TILE_WIDTH = 16) is as follows:

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ nvcc -arch=compute_35 -L/usr/local/cuda/lib64 -lcublas ./matrix_mul.cu -Wno-deprecated-gpu-targets
./matrix_mul.cu(202): warning #177-D: variable "alpha" was declared but never referenced

./matrix_mul.cu(203): warning #177-D: variable "beta" was declared but never referenced

./matrix_mul.cu(18): warning #177-D: variable "BLOCK_SIZE" was declared but never referenced

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 1 1000
Kernel Elpased Time: 0.802 ms
Performance= 2493.63 GFlop/s, Time= 0.802 msec, Size= 2000000000 Ops
Computing result using host CPU...done.
Listing first 100 Differences > 0.000010...

  Total Errors = 0

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 256
Kernel Elpased Time: 0.691 ms
Performance= 48.57 GFlop/s, Time= 0.691 msec, Size= 33554432 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 1024
Kernel Elpased Time: 0.907 ms
Performance= 2367.84 GFlop/s, Time= 0.907 msec, Size= 2147483648 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 2048
Kernel Elpased Time: 5.411 ms
Performance= 3175.05 GFlop/s, Time= 5.411 msec, Size= 17179869184 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 4096
Kernel Elpased Time: 40.593 ms
Performance= 3385.77 GFlop/s, Time= 40.593 msec, Size= 137438953472 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 10000
Kernel Elpased Time: 577.970 ms
Performance= 3460.38 GFlop/s, Time= 577.970 msec, Size= 2000000000000 Ops

Halving the block size (to TILE_WIDTH = 8) also halves the performance.

Lab5 CUDA: Optimizing matrix multiplication on the GPU

Optimizing GPU matrix multiplication with shared memory

The indices \(aRow\), \(aCol\), \(bRow\), \(bCol\) determine which elements of matrices \(A\) and \(B\) each thread accesses. They are computed as follows:

\begin{align*}
aRow &= BLOCK\_SIZE \cdot by + ty \\
aCol &= a - aBegin + tx \\
bRow &= \frac{b - bBegin}{wB} + ty \\
bCol &= BLOCK\_SIZE \cdot bx + tx
\end{align*}

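\(BLOCK\_SIZE\): the size of the tile processed by a block.
\(wB\): the width of matrix \(B\).
\(a\): the start index of the tile of \(A\) currently being processed.
\(aBegin\): the start index of the first tile of \(A\).
\(b\): the start index of the tile of \(B\) currently being processed.
\(bBegin\): the start index of the first tile of \(B\).
\(bx\): the block index in the horizontal direction.
\(by\): the block index in the vertical direction.
\(ty\): the row index of the thread within the block.
\(tx\): the column index of the thread within the block.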

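For each pair of tiles, each thread adds to \(Csub\) the dot product of one row of \(As\) (the sub-matrix of \(A\)) and one column of \(Bs\) (the sub-matrix of \(B\)); after the last tile, \(Csub\) is one element of the output matrix \(C\):
\[Csub \mathrel{+}= \sum_{k=0}^{BLOCK\_SIZE - 1} As[ty][k] \cdot Bs[k][tx]\]

After the computation, each thread writes its result to the output matrix \(C\). The indices for the write are:
\begin{align*}
cRow &= BLOCK\_SIZE \cdot by + ty \\
cCol &= BLOCK\_SIZE \cdot bx + tx
\end{align*}

A boundary check ensures that out-of-range indices are ignored. The kernel has no height parameter for \(A\), so \(wA\) serves as the row bound, which is valid because \(A\) is square here:
\[\text{if } (cRow < wA) \text{ and } (cCol < wB), \text{ then store } Csub.\]
The result is stored as:
\[C[cRow \cdot wB + cCol] = Csub\]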
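The tile-by-tile procedure can be emulated on the CPU to check the index formulas and the zero padding for sizes that are not multiples of the tile size. This is a sequential C++ sketch with a hypothetical function name. Like the kernel, it assumes A is square. Its "load" and "compute" phases correspond to the work between the two __syncthreads() barriers. The tile loop runs while a <= aEnd. With aEnd = aBegin + wA - 1, a strict < would skip the last tile whenever wA % BLOCK_SIZE == 1 (for example, wA = 3 with 2 × 2 tiles).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sequential emulation of the shared-memory kernel: A is wA x wA, B is wA x wB,
// C is wA x wB (all row-major); BS is BLOCK_SIZE.
void emulate_tiled_kernel(const std::vector<float>& A, const std::vector<float>& B,
                          std::vector<float>& C, int wA, int wB, int BS) {
    assert(BS > 0);
    const std::size_t tile = static_cast<std::size_t>(BS) * BS;
    std::vector<float> As(tile), Bs(tile), Csub(tile);   // "shared memory" and per-thread sums
    const int gridX = (wB + BS - 1) / BS, gridY = (wA + BS - 1) / BS;
    for (int by = 0; by < gridY; ++by)
        for (int bx = 0; bx < gridX; ++bx) {
            std::fill(Csub.begin(), Csub.end(), 0.0f);
            const int aBegin = wA * BS * by, aEnd = aBegin + wA - 1, bBegin = BS * bx;
            for (int a = aBegin, b = bBegin; a <= aEnd; a += BS, b += BS * wB) {
                // Load phase: "thread" (ty, tx) fills one element of As and Bs.
                for (int ty = 0; ty < BS; ++ty)
                    for (int tx = 0; tx < BS; ++tx) {
                        const int aRow = BS * by + ty, aCol = a - aBegin + tx;
                        const int bRow = (b - bBegin) / wB + ty, bCol = BS * bx + tx;
                        As[ty * BS + tx] = (aRow < wA && aCol < wA) ? A[a + wA * ty + tx] : 0.0f;
                        Bs[ty * BS + tx] = (bRow < wA && bCol < wB) ? B[b + wB * ty + tx] : 0.0f;
                    }
                // Compute phase: each "thread" adds one row of As times one column of Bs.
                for (int ty = 0; ty < BS; ++ty)
                    for (int tx = 0; tx < BS; ++tx)
                        for (int k = 0; k < BS; ++k)
                            Csub[ty * BS + tx] += As[ty * BS + k] * Bs[k * BS + tx];
            }
            // Bounded store of the block's results.
            for (int ty = 0; ty < BS; ++ty)
                for (int tx = 0; tx < BS; ++tx) {
                    const int cRow = BS * by + ty, cCol = BS * bx + tx;
                    if (cRow < wA && cCol < wB) C[cRow * wB + cCol] = Csub[ty * BS + tx];
                }
        }
}
```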

// #define USE_CUBLAS

#include <iostream>
#include <cstdio>
#include <cuda_runtime.h>
#ifdef USE_CUBLAS
#include <cublas_v2.h>
#endif
#include <device_launch_parameters.h>
#include <cmath>
using namespace std;

const int TILE_WIDTH = 16;	// Thread block (tile) size

/////////
// Matrix multiplication with shared memory (CUDA Kernel) on the device: C = A * B
/////////
const int BLOCK_SIZE = TILE_WIDTH;
__global__ void MatrixMulSharedMemKernel(float *A,
    float *B, float *C, int wA,
    int wB) {
  // Block index
  int bx = blockIdx.x;
  int by = blockIdx.y;

  // Thread index
  int tx = threadIdx.x;
  int ty = threadIdx.y;

  // Index of the first sub-matrix of A processed by the block
  int aBegin = wA * BLOCK_SIZE * by;

  // Index of the last sub-matrix of A processed by the block
  int aEnd   = aBegin + wA - 1;

  // Step size used to iterate through the sub-matrices of A
  int aStep  = BLOCK_SIZE;

  // Index of the first sub-matrix of B processed by the block
  int bBegin = BLOCK_SIZE * bx;

  // Step size used to iterate through the sub-matrices of B
  int bStep  = BLOCK_SIZE * wB;

  // Csub is used to store the element of the block sub-matrix
  // that is computed by the thread
  float Csub = 0;

  // Loop over all the sub-matrices of A and B
  // required to compute the block sub-matrix
  for (int a = aBegin, b = bBegin;
       a <= aEnd;   // "<=": with "<" the last tile is skipped when wA % BLOCK_SIZE == 1
       a += aStep, b += bStep) {
    // Declaration of the shared memory array As used to
    // store the sub-matrix of A
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];

    // Declaration of the shared memory array Bs used to
    // store the sub-matrix of B
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];

    // Load the matrices from device memory
    // to shared memory; each **thread** loads
    // one element of each matrix.
    // Boundary checks keep the accesses inside the matrices; elements outside
    // are set to 0.0f so they do not contribute to the dot product.
    // A is square in this lab, so wA is also its number of rows; B has wA rows.
    int aRow = BLOCK_SIZE * by + ty;
    int aCol = a - aBegin + tx;
    int bRow = (b - bBegin) / wB + ty;
    int bCol = BLOCK_SIZE * bx + tx;
    As[ty][tx] = (aRow < wA && aCol < wA) ? A[a + wA * ty + tx] : 0.0f;
    Bs[ty][tx] = (bRow < wA && bCol < wB) ? B[b + wB * ty + tx] : 0.0f;

    // Synchronize to make sure the matrices are loaded
    __syncthreads();

    // Multiply the two matrices together;
    // each thread computes one element
    // of the block sub-matrix
#pragma unroll
    for (int k = 0; k < BLOCK_SIZE; ++k) {
      Csub += As[ty][k] * Bs[k][tx];
    }

    // Synchronize to make sure that the preceding
    // computation is done before loading two new
    // sub-matrices of A and B in the next iteration
    __syncthreads();
  }

  // Write the block sub-matrix to device memory;
  // each thread writes one element
  int c = wB * BLOCK_SIZE * by + BLOCK_SIZE * bx;
  // Store Csub only if (cRow, cCol) lies inside C, which has wA rows and wB columns
  int cRow = BLOCK_SIZE * by + ty;
  int cCol = BLOCK_SIZE * bx + tx;
  if (cRow < wA && cCol < wB) {
    C[c + wB * ty + tx] = Csub;
  }
}


//! For square matrices only
__global__ void MatrixMulKernel(float* d_M, float* d_N, float* d_P, int width)
{
  // Calculate the row index of the P element and M
  // *** TO DO: Compute the row index for the current thread ***
  // int row = ...;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  // Calculate the column index of the P element and N
  // *** TO DO: Compute the column index for the current thread ***
  // int col = ...;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  // Ensure the thread is within bounds
  if ( (row < width) && (col < width) ) {
    float pValue = 0.0;

    // Each thread computes one element of the matrix
    // *** TO DO: Implement the matrix multiplication for a single element ***
    for (int k = 0; k < width; k++)
        pValue += d_M[row * width + k] * d_N[k * width + col];
    // Store the computed value into the output matrix
    // *** TO DO: Write the computed value to the correct position in d_P ***
    // d_P[row * width + col] = ...;
    d_P[row * width + col] = pValue;
  }
}

////////////////////////////////////////////////////////////////////////////////
//! Compute reference data set matrix multiply on CPU
//! C = A * B
//! @param C          reference data, computed but preallocated
//! @param A          matrix A as provided to device
//! @param B          matrix B as provided to device
//! @param hA         height of matrix A
//! @param wA         width of matrix A
//! @param wB         width of matrix B
////////////////////////////////////////////////////////////////////////////////
void
matrixMulCPU(float *C, const float *A, const float *B, unsigned int hA, unsigned int wA, unsigned int wB)
{
    for (unsigned int i = 0; i < hA; ++i)
        for (unsigned int j = 0; j < wB; ++j)
        {
            double sum = 0;

            for (unsigned int k = 0; k < wA; ++k)
            {
                double a = A[i * wA + k];
                double b = B[k * wB + j];
                sum += a * b;
            }

            C[i * wB + j] = (float)sum;
        }
}

void printDiff(float *data1, float *data2, int width, int height, int iListLength, float fListTol)
{
    printf("Listing first %d Differences > %.6f...\n", iListLength, fListTol);
    int i,j,k;
    int error_count=0;

    for (j = 0; j < height; j++)
    {
        for (i = 0; i < width; i++)
        {
            k = j * width + i;
            float fDiff = fabs(data1[k] - data2[k]);

            if (fDiff > fListTol)
            {
                if (error_count < iListLength)
                {
                    printf("    Loc(%d,%d)\tCPU=%.5f\tGPU=%.5f\tDiff=%.6f\n", i, j, data1[k], data2[k], fDiff);
                }

                error_count++;
            }
        }
    }

    printf(" \n  Total Errors = %d\n", error_count);
}

void getArg(int argc, char* argv[], int &size, int &check)
{
  if (argc != 3)
  {
    cerr << "Usage: " << argv[0] << " <check_enable> <size>\n";
    cerr << "\tcheck_enable: 1 to enable result checking\n";
    cerr << "\tsize: size of the matrix\n";
    exit(1);
  }

  int val1, val2;
  try
  {
    val1 = stoi(argv[1]);
    val2 = stoi(argv[2]);
  }
  catch (const invalid_argument& e)
  {
    cerr << "ERROR: parameters should be integer\n";
    exit(1);
  }

  check = val1;
  size = val2;
}



int main(int argc, char* argv[])
{
  int size, check;
  getArg(argc, argv, size, check);

  int m = size, n = size, k = size;
  
  // Declare the host (h_*) and device (d_*) arrays
  float *h_M, *h_N, *d_M, *d_N;
  float *h_P, *d_P;
  
  size_t sizeM = m * k * sizeof(float);
  size_t sizeN = k * n * sizeof(float);
  size_t sizeP = m * n * sizeof(float);

  // Allocate host memory
  h_M = (float*) malloc(sizeM);
  h_N = (float*) malloc(sizeN);
  h_P = (float*) malloc(sizeP);
  float *reference = (float *)malloc(sizeP);

  // Allocate device memory
  cudaMalloc(&d_M, sizeM);
  cudaMalloc(&d_N, sizeN);
  cudaMalloc(&d_P, sizeP);

  // Init data 
  for(int i = 0; i < m * n; ++i)
  {
    if(i % 2 == 0)
      h_M[i] = 1.0;
    else
      h_M[i] = 0.5;
  }

  for(int i = 0;i < n * k; ++i)
  {
    if(i % 2 == 0)
      h_N[i] = 0.5;
    else
      h_N[i] = 1.0;
  }

  // Copy data from CPU to GPU
  cudaMemcpy(d_M, h_M, sizeM, cudaMemcpyHostToDevice);
  cudaMemcpy(d_N, h_N, sizeN, cudaMemcpyHostToDevice);

  // Timing records 
  cudaEvent_t start,stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start,0);

  // Launch configuration: define the grid and block dimensions
  dim3 grid((int)ceil(k*1.0 / TILE_WIDTH), (int)ceil(m*1.0/ TILE_WIDTH));
  dim3 block(TILE_WIDTH, TILE_WIDTH);
  
  int nIter = 5;
#ifdef USE_CUBLAS
  cublasHandle_t handle;
  cublasCreate(&handle);
#endif
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  for (int j = 0; j < nIter; j++) {
    //matrixMulCPU(reference, h_M, h_N, m, k, n);
    //MatrixMulKernel<<<grid, block>>>(d_M, d_N, d_P, m);
    MatrixMulSharedMemKernel<<<grid, block>>>(d_M, d_N, d_P, m, n);
    //cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, d_N, n, d_M, k, &beta, d_P, n);
  }

  cudaEventRecord(stop, 0);
  cudaEventSynchronize(stop);
  float msecPerMatrixMul;
  cudaEventElapsedTime(&msecPerMatrixMul, start, stop);
  msecPerMatrixMul /= nIter;
  printf("Kernel Elpased Time: %.3f ms\n", msecPerMatrixMul);

  // Compute and print the performance
  double flopsPerMatrixMul = 2.0 * (double)m * (double)n * (double)k;
  double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
  printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
		  gigaFlops,
		  msecPerMatrixMul,
		  flopsPerMatrixMul);

  // Copy data from GPU to CPU 
  cudaMemcpy(h_P, d_P, sizeP, cudaMemcpyDeviceToHost);

  // compute reference solution
  if (check == 1)
  {
    printf("Computing result using host CPU...");
    matrixMulCPU(reference, h_M, h_N, m, k, n);
    printf("done.\n");
    printDiff(reference, h_P, n, m, 100, 1.0e-5f);
  }

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  free(h_P);
  free(reference);
  free(h_M);
  free(h_N);
  cudaFree(d_P);
  cudaFree(d_M);
  cudaFree(d_N);
#ifdef USE_CUBLAS
  cublasDestroy(handle);
#endif

  return 0;
}

The output with block size 16 is shown below; performance improved by roughly 30%:

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ nvcc -arch=compute_35 -L/usr/local/cuda/lib64 -lcublas ./matrix_mul.cu -Wno-deprecated-gpu-targets
./matrix_mul.cu(107): warning #177-D: variable "c" was declared but never referenced

./matrix_mul.cu(293): warning #177-D: variable "alpha" was declared but never referenced

./matrix_mul.cu(294): warning #177-D: variable "beta" was declared but never referenced

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 1 1000
Kernel Elpased Time: 0.712 ms
Performance= 2810.55 GFlop/s, Time= 0.712 msec, Size= 2000000000 Ops
Computing result using host CPU...done.
Listing first 100 Differences > 0.000010...

  Total Errors = 0

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 256
Kernel Elpased Time: 0.595 ms
Performance= 56.39 GFlop/s, Time= 0.595 msec, Size= 33554432 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 1024
Kernel Elpased Time: 0.726 ms
Performance= 2959.59 GFlop/s, Time= 0.726 msec, Size= 2147483648 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 2048
Kernel Elpased Time: 4.214 ms
Performance= 4076.64 GFlop/s, Time= 4.214 msec, Size= 17179869184 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 4096
Kernel Elpased Time: 31.043 ms
Performance= 4427.37 GFlop/s, Time= 31.043 msec, Size= 137438953472 Ops

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 10000
Kernel Elpased Time: 447.665 ms
Performance= 4467.63 GFlop/s, Time= 447.665 msec, Size= 2000000000000 Ops
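To make the tiling and boundary logic easier to check, here is a CPU-only C sketch that emulates the shared-memory kernel for square row-major matrices. The names ceil_div, naive_matmul, tiled_matmul, tiles_visited and MAX_T are hypothetical and not part of the lab code. It also illustrates a subtle point about the template's loop condition: with a < aEnd, where aEnd = aBegin + wA - 1, the loop stops one tile early whenever wA % BLOCK_SIZE == 1 (for example wA = 17), while the inclusive a <= aEnd used in the CUDA samples' matrixMul visits every tile. None of the sizes measured above (256, 1000, 1024, 2048, 4096, 10000) hit that case.

```c
#include <assert.h>
#include <stdbool.h>

// Hypothetical helper: integer ceiling division, same result as
// (int)ceil(x * 1.0 / tile) for non-negative x, without going through double.
int ceil_div(int x, int tile) { return (x + tile - 1) / tile; }

// Reference result: plain triple loop over square n x n row-major matrices.
void naive_matmul(const float* A, const float* B, float* C, int n) {
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < n; ++j) {
      float sum = 0.0f;
      for (int k = 0; k < n; ++k) sum += A[i * n + k] * B[k * n + j];
      C[i * n + j] = sum;
    }
}

#define MAX_T 32  // largest tile size this sketch supports

// CPU emulation of the shared-memory kernel for square n x n row-major
// matrices with tile size T. (by, bx) plays the thread block, As/Bs play the
// shared-memory tiles, and accumulating into C stands in for each thread's Csub.
void tiled_matmul(const float* A, const float* B, float* C, int n, int T) {
  assert(T > 0 && T <= MAX_T);
  float As[MAX_T][MAX_T], Bs[MAX_T][MAX_T];
  const int blocks = ceil_div(n, T);
  for (int i = 0; i < n * n; ++i) C[i] = 0.0f;
  for (int by = 0; by < blocks; ++by)
    for (int bx = 0; bx < blocks; ++bx)
      for (int t = 0; t < blocks; ++t) {  // one pass per pair of tiles
        // Load phase: each (ty, tx) loads one element of each tile,
        // writing 0.0f where the global index falls outside the matrix.
        for (int ty = 0; ty < T; ++ty)
          for (int tx = 0; tx < T; ++tx) {
            int aRow = by * T + ty, aCol = t * T + tx;
            int bRow = t * T + ty, bCol = bx * T + tx;
            As[ty][tx] = (aRow < n && aCol < n) ? A[aRow * n + aCol] : 0.0f;
            Bs[ty][tx] = (bRow < n && bCol < n) ? B[bRow * n + bCol] : 0.0f;
          }
        // Compute phase: partial dot product over this pair of tiles.
        for (int ty = 0; ty < T; ++ty)
          for (int tx = 0; tx < T; ++tx) {
            int cRow = by * T + ty, cCol = bx * T + tx;
            if (cRow >= n || cCol >= n) continue;  // mirrors the store check
            float sum = 0.0f;
            for (int k = 0; k < T; ++k) sum += As[ty][k] * Bs[k][tx];
            C[cRow * n + cCol] += sum;
          }
      }
}

// Number of tiles the kernel's loop visits for row width wA (taking aBegin = 0)
// with the exclusive (a < aEnd) or inclusive (a <= aEnd) condition.
int tiles_visited(int wA, int T, bool inclusive) {
  const int aEnd = wA - 1;
  int count = 0;
  for (int a = 0; inclusive ? a <= aEnd : a < aEnd; a += T) ++count;
  return count;
}
```

Integer ceiling division like ceil_div also avoids the round trip through double in (int)ceil(k*1.0 / TILE_WIDTH) when computing the grid size.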

GPU matrix multiplication optimization with cuBLAS

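To switch to cuBLAS, define the USE_CUBLAS macro, which includes cublas_v2.h and creates and destroys the cuBLAS handle, then replace the kernel launch in the timing loop with the cublasSgemm call (and link with -lcublas, as in the nvcc command above):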

#define USE_CUBLAS

#include <iostream>
#include <cstdio>
#include <cuda_runtime.h>
#ifdef USE_CUBLAS
#include <cublas_v2.h>
#endif
#include <device_launch_parameters.h>
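#include <cmath>
#include <cstdlib>    // malloc, free, exit
#include <stdexcept>  // std::invalid_argument
#include <string>     // std::stoi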
using namespace std;

const int TILE_WIDTH = 16;	// thread-block (tile) size

/////////
// Matrix multiplication with shared memory (CUDA Kernel) on the device: C = A * B
/////////
const int BLOCK_SIZE = TILE_WIDTH;
__global__ void MatrixMulSharedMemKernel(float *A,
    float *B, float *C, int wA,
    int wB) {
  // Block index
  int bx = blockIdx.x;
  int by = blockIdx.y;

  // Thread index
  int tx = threadIdx.x;
  int ty = threadIdx.y;

  // Index of the first sub-matrix of A processed by the block
  int aBegin = wA * BLOCK_SIZE * by;

  // Index of the last sub-matrix of A processed by the block
  int aEnd   = aBegin + wA - 1;

  // Step size used to iterate through the sub-matrices of A
  int aStep  = BLOCK_SIZE;

  // Index of the first sub-matrix of B processed by the block
  int bBegin = BLOCK_SIZE * bx;

  // Step size used to iterate through the sub-matrices of B
  int bStep  = BLOCK_SIZE * wB;

  // Csub is used to store the element of the block sub-matrix
  // that is computed by the thread
  float Csub = 0;

  // Loop over all the sub-matrices of A and B
  // required to compute the block sub-matrix
  for (int a = aBegin, b = bBegin;
       a <= aEnd;  // inclusive: the last tile starts exactly at aEnd when wA % BLOCK_SIZE == 1
       a += aStep, b += bStep) {
    // Declaration of the shared memory array As used to
    // store the sub-matrix of A
    __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];

    // Declaration of the shared memory array Bs used to
    // store the sub-matrix of B
    __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];

    // Load the matrices from device memory
    // to shared memory; each **thread** loads
    // one element of each matrix
    // --- TO DO :Load the elements of the sub-matrix of A into As ---
    // ---        Load the elements of the sub-matrix of B into Bs ---
    // NOTE: Ensure that the thread indices do not exceed the matrix dimensions to avoid out-of-bounds access.
    //       Use boundary checks to load valid elements into shared memory, and set invalid elements to 0.0f
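    int aRow = BLOCK_SIZE * by + ty;    // global row of A
    int aCol = a - aBegin + tx;         // global column of A: tile offset + tx
    int bRow = (b - bBegin) / wB + ty;  // global row of B: tile offset + ty
    int bCol = BLOCK_SIZE * bx + tx;    // global column of B

    // wA doubles as the height of A, so these checks (and the store check
    // below) assume square matrices, which is what main uses.
    if (aRow < wA && aCol < wA)
      As[ty][tx] = A[aRow * wA + aCol];
    else
      As[ty][tx] = 0.0f;

    if (bRow < wA && bCol < wB)
      Bs[ty][tx] = B[bRow * wB + bCol];
    else
      Bs[ty][tx] = 0.0f;
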
    // Synchronize to make sure the matrices are loaded
    __syncthreads();

    // Multiply the two matrices together;
    // each thread computes one element
    // of the block sub-matrix
#pragma unroll
    // --- TO DO :Implement the matrix multiplication using the sub-matrices As and Bs ---
    for (int k = 0; k < BLOCK_SIZE; ++k)
      Csub += As[ty][k] * Bs[k][tx];

    // Synchronize to make sure that the preceding
    // computation is done before loading two new
    // sub-matrices of A and B in the next iteration
    __syncthreads();
  }

  // Write the block sub-matrix to device memory;
  // each thread writes one element
  int c = wB * BLOCK_SIZE * by + BLOCK_SIZE * bx;
  // --- TO DO :Store the computed Csub result into matrix C ---
  // NOTE: Ensure that the thread indices "c" do not exceed the matrix dimensions to avoid out-of-bounds access.
  //       Use boundary checks to write valid elements to the output matrix C.
  int cRow = BLOCK_SIZE * by + ty;
  int cCol = BLOCK_SIZE * bx + tx;
  if (cRow < wA && cCol < wB)
      C[cRow * wB + cCol] = Csub;
}


//! For square matrices only
__global__ void MatrixMulKernel(float* d_M, float* d_N, float* d_P, int width)
{
  // Calculate the row index of the P element and M
  // *** TO DO: Compute the row index for the current thread ***
  // int row = ...;
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  // Calculate the column index of the P element and N
  // *** TO DO: Compute the column index for the current thread ***
  // int col = ...;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  // Ensure the thread is within bounds
  if ( (row < width) && (col < width) ) {
    float pValue = 0.0;

    // Each thread computes one element of the matrix
    // *** TO DO: Implement the matrix multiplication for a single element ***
    for (int k = 0; k < width; k++)
        pValue += d_M[row * width + k] * d_N[k * width + col];
    // Store the computed value into the output matrix
    // *** TO DO: Write the computed value to the correct position in d_P ***
    // d_P[row * width + col] = ...;
    d_P[row * width + col] = pValue;
  }
}

////////////////////////////////////////////////////////////////////////////////
//! Compute reference data set matrix multiply on CPU
//! C = A * B
//! @param C          reference data, computed but preallocated
//! @param A          matrix A as provided to device
//! @param B          matrix B as provided to device
//! @param hA         height of matrix A
//! @param wA         width of matrix A
//! @param wB         width of matrix B
////////////////////////////////////////////////////////////////////////////////
void
matrixMulCPU(float *C, const float *A, const float *B, unsigned int hA, unsigned int wA, unsigned int wB)
{
    for (unsigned int i = 0; i < hA; ++i)
        for (unsigned int j = 0; j < wB; ++j)
        {
            double sum = 0;

            for (unsigned int k = 0; k < wA; ++k)
            {
                double a = A[i * wA + k];
                double b = B[k * wB + j];
                sum += a * b;
            }

            C[i * wB + j] = (float)sum;
        }
}

void printDiff(float *data1, float *data2, int width, int height, int iListLength, float fListTol)
{
    printf("Listing first %d Differences > %.6f...\n", iListLength, fListTol);
    int i,j,k;
    int error_count=0;

    for (j = 0; j < height; j++)
    {
        for (i = 0; i < width; i++)
        {
            k = j * width + i;
            float fDiff = fabs(data1[k] - data2[k]);

            if (fDiff > fListTol)
            {
                if (error_count < iListLength)
                {
                    printf("    Loc(%d,%d)\tCPU=%.5f\tGPU=%.5f\tDiff=%.6f\n", i, j, data1[k], data2[k], fDiff);
                }

                error_count++;
            }
        }
    }

    printf(" \n  Total Errors = %d\n", error_count);
}

void getArg(int argc, char* argv[], int &size, int &check)
{
  if (argc != 3)
  {
    cerr << "Usage: " << argv[0] << " <check_enable> <size>\n";
    cerr << "\tcheck_enable: 1 to enable result checking\n";
    cerr << "\tsize: size of the matrix\n";
    exit(1);
  }

  int val1, val2;
  try
  {
    val1 = stoi(argv[1]);
    val2 = stoi(argv[2]);
  }
  catch (const invalid_argument& e)
  {
    cerr << "ERROR: parameters should be integer\n";
    exit(1);
  }

  check = val1;
  size = val2;
}



int main(int argc, char* argv[])
{
  int size, check;
  getArg(argc, argv, size, check);

  int m = size, n = size, k = size;
  
  // Declare the host (h_*) and device (d_*) arrays
  float *h_M, *h_N, *d_M, *d_N;
  float *h_P, *d_P;
  
  size_t sizeM = m * k * sizeof(float);
  size_t sizeN = k * n * sizeof(float);
  size_t sizeP = m * n * sizeof(float);

  // Allocate host memory
  h_M = (float*) malloc(sizeM);
  h_N = (float*) malloc(sizeN);
  h_P = (float*) malloc(sizeP);
  float *reference = (float *)malloc(sizeP);

  // Allocate device memory
  cudaMalloc(&d_M, sizeM);
  cudaMalloc(&d_N, sizeN);
  cudaMalloc(&d_P, sizeP);

  // Init data 
  for(int i = 0; i < m * n; ++i)
  {
    if(i % 2 == 0)
      h_M[i] = 1.0;
    else
      h_M[i] = 0.5;
  }

  for(int i = 0;i < n * k; ++i)
  {
    if(i % 2 == 0)
      h_N[i] = 0.5;
    else
      h_N[i] = 1.0;
  }

  // Copy data from CPU to GPU
  cudaMemcpy(d_M, h_M, sizeM, cudaMemcpyHostToDevice);
  cudaMemcpy(d_N, h_N, sizeN, cudaMemcpyHostToDevice);

  // Timing records 
  cudaEvent_t start,stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start,0);

  // Launch configuration: define the grid and block dimensions
  dim3 grid((int)ceil(k*1.0 / TILE_WIDTH), (int)ceil(m*1.0/ TILE_WIDTH));
  dim3 block(TILE_WIDTH, TILE_WIDTH);
  
  int nIter = 5;
#ifdef USE_CUBLAS
  cublasHandle_t handle;
  cublasCreate(&handle);
#endif
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  for (int j = 0; j < nIter; j++) {
    //matrixMulCPU(reference, h_M, h_N, m, k, n);
    //MatrixMulKernel<<<grid, block>>>(d_M, d_N, d_P, m);
    //MatrixMulSharedMemKernel<<<grid, block>>>(d_M, d_N, d_P, m, n);
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, d_N, n, d_M, k, &beta, d_P, n);
  }

  cudaEventRecord(stop, 0);
  cudaEventSynchronize(stop);
  float msecPerMatrixMul;
  cudaEventElapsedTime(&msecPerMatrixMul, start, stop);
  msecPerMatrixMul /= nIter;
  printf("Kernel Elpased Time: %.3f ms\n", msecPerMatrixMul);

  // Compute and print the performance
  double flopsPerMatrixMul = 2.0 * (double)m * (double)n * (double)k;
  double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f);
  printf("Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n",
		  gigaFlops,
		  msecPerMatrixMul,
		  flopsPerMatrixMul);

  // Copy data from GPU to CPU 
  cudaMemcpy(h_P, d_P, sizeP, cudaMemcpyDeviceToHost);

  // compute reference solution
  if (check == 1)
  {
    printf("Computing result using host CPU...");
    matrixMulCPU(reference, h_M, h_N, m, k, n);
    printf("done.\n");
    printDiff(reference, h_P, n, m, 100, 1.0e-5f);
  }

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  free(h_P);
  free(reference);
  free(h_M);
  free(h_N);
  cudaFree(d_P);
  cudaFree(d_M);
  cudaFree(d_N);
#ifdef USE_CUBLAS
  cublasDestroy(handle);
#endif

  return 0;
}
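cuBLAS assumes column-major storage, while this lab stores matrices in row-major order. That is why the cublasSgemm call passes d_N first and uses n, m, k as the dimensions: a row-major buffer read as column-major is the transpose, so cuBLAS computes P^T = N^T M^T, which is exactly P in row-major layout. The C sketch below checks this on a non-square example with a naive column-major GEMM; sgemm_col_major and matmul_row_major are hypothetical stand-ins for illustration, not the cuBLAS implementation.

```c
#include <assert.h>

// Naive column-major SGEMM for the no-transpose case, with the same argument
// order as cublasSgemm: C (m x n) = alpha * A (m x k) * B (k x n) + beta * C.
// Element (i, j) of a column-major matrix X with leading dimension ldx is
// X[i + j * ldx].
void sgemm_col_major(int m, int n, int k, float alpha, const float* A, int lda,
                     const float* B, int ldb, float beta, float* C, int ldc) {
  assert(lda >= m && ldb >= k && ldc >= m);
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float sum = 0.0f;
      for (int p = 0; p < k; ++p) sum += A[i + p * lda] * B[p + j * ldb];
      // When beta == 0, C is treated as output-only and never read.
      C[i + j * ldc] = (beta == 0.0f) ? alpha * sum
                                      : alpha * sum + beta * C[i + j * ldc];
    }
}

// Row-major P (m x n) = M (m x k) * N (k x n), with the operands swapped as in
//   cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
//               d_N, n, d_M, k, &beta, d_P, n);
// Read as column-major, the row-major buffers hold N^T (n x k) and M^T (k x m),
// so the call produces P^T = N^T * M^T (n x m, column-major) = P in row-major.
void matmul_row_major(const float* M, const float* N, float* P,
                      int m, int n, int k) {
  sgemm_col_major(n, m, k, 1.0f, N, n, M, k, 0.0f, P, n);
}
```

For the square matrices used in this lab, forgetting the swap would silently compute N M instead of M N, which the CPU check would flag as errors.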

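The result is shown below. On an H100, cuBLAS performs the size-10000 multiplication roughly 10× faster than the shared-memory kernel (43.667 ms vs. 447.665 ms).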

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab4-5]
└─$ ./a.out 0 10000
Kernel Elpased Time: 43.667 ms
Performance= 45800.76 GFlop/s, Time= 43.667 msec, Size= 2000000000000 Ops
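The GFLOP/s figures follow from the 2mnk floating-point operations the program counts (one multiply and one add per inner-product term). Worked out for the two size-10000 runs; the printed times are rounded to three decimals, so the results match the logged figures only approximately:

$$
\begin{aligned}
\text{FLOPs} &= 2mnk = 2n^3 = 2 \times 10000^3 = 2 \times 10^{12} = 2000\ \text{GFLOP} \\
\text{shared memory: } \frac{2000\ \text{GFLOP}}{0.447665\ \text{s}} &\approx 4468\ \text{GFLOP/s} \\
\text{cuBLAS: } \frac{2000\ \text{GFLOP}}{0.043667\ \text{s}} &\approx 45\,800\ \text{GFLOP/s} \\
\text{speedup} &= \frac{447.665\ \text{ms}}{43.667\ \text{ms}} \approx 10.25
\end{aligned}
$$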

Lab6 LLaMA optimization

First, build the baseline with make run.

CPU Baseline

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories15M.bin
One day, a little girl named Amy found a very ugly teaspoon. It was so ugly, but Amy liked it. She decided to use it to eat her snack.
As she ate, her friend Tom came to play. Tom saw the ugly teaspoon and laughed. "What a silly teaspoon!" he said. "I have never seen it before." Amy was sad. She did not want Tom to make fun of her ugly teaspoon.
Amy had an idea. She used the ugly teaspoon to bake a sweet cookie. Tom saw the cookie and said, "I don't like it!" Amy said, "It's okay, we can share." Tom smiled, and they both enjoyed the yummy cookie. They laughed and played together, happy that the ugly teaspoon and cookie were enjoyed.
achieved tok/s: 168.819188

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories15M.bin
Once upon a time, there was a little girl named Lily. She loved to paint. One day, she found a dull pencil on the ground. Lily wanted to paint a big, beautiful picture.
Lily took her paint and brush and started to paint. She painted a red apple, a blue banana, and a green frog. She was very happy with her painting. Then, she showed her painting to her friend, Tim.
"Look at my picture!" Lily said to Tim. Tim looked at the picture and smiled. "That's a great painting, Lily!" he said. They both laughed and played with the dull pencil and painting more pictures. And they lived happily ever after.
achieved tok/s: 170.305677

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories15M.bin
Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she found a treasure map. The map showed a big X in the desert where treasure was buried.
Lily decided to follow the map and see where the treasure was. She walked and walked for a long time until she found the big X. She dug and dug until she found a spot where the treasure was buried.
But when she tried to go back home, she found that her shoelaces were untied. She asked for help, but no one could tie them. She felt sad and wished she had returned home earlier.
The moral of the story is that sometimes we need to make a choice to find a solution, even if it's hard. And sometimes, trying something new can lead to trouble.
achieved tok/s: 168.362627

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories15M.bin
One day, a thin cat named Tom found a book. He liked the book because it had a lot of words. Tom wanted to read the book, but it was closed. He tried to open it, but it was hard. So, Tom had a plan.
Tom saw a boy named Tim. Tom asked Tim, "Can you help me open the book?" Tim said, "Yes, I can help!" Tim opened the book for Tom. They both looked at the words inside.
Tim and Tom were having fun. They wanted to make a spell to open the book. They found a spell and gave it to Tom. They said the words together. Suddenly, the book opened by itself! Inside the book, they found a magic stone. Tom and Tim were very happy. They used the magic stone to make their wishes come true.
achieved tok/s: 167.468720

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories15M.bin
Once upon a time, there was a little girl named Lily. She loved to play with her toys and draw pictures. One day, her mommy asked her to take a bath and turn on the faucet. Lily didn't want to, but her mommy said she would feel better after.
After her bath, Lily went to bed. But she couldn't sleep because she was scared of the dark. She called out for her mommy, "Mommy, I'm scared!" Her mommy came in and turned on the faucet. Lily felt safe again and fell asleep.
The next day, Lily's mommy asked her to clean up her toys. But Lily didn't want to. Her mommy said, "If you don't clean up your toys, you won't be able to find them later." Lily didn't want that, so she started cleaning up. But soon, she realized that it was much easier to find her toys when she was ready to go to bed.
achieved tok/s: 168.382353

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories42M.bin
Once upon a time, there was a little girl named Lily. One day, she saw a candle on the table and she wanted to touch it. But her mom said, "No, Lily! That candle is hot and it can hurt you." Lily listened to her mom and didn't touch the candle.
Later that night, Lily was scared of the dark. Her mom said, "Don't be afraid, Lily. I am here with you." Lily felt better and went to sleep. The next day, Lily went to the park with her friend Timmy. They played on the swings and went down the slide.
Suddenly, Timmy pinched Lily's arm. "Ouch!" she cried. Timmy said, "I'm sorry, Lily. I didn't mean to hurt you." Lily forgave Timmy and they continued to play. When it was time to go home, Lily's mom gave her a soft teddy bear as a gift. Lily was happy and went to bed with a smile on her face.
achieved tok/s: 57.266603

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories42M.bin
Once upon a time, there was a little bird named Tweetie. Tweetie had beautiful feathers that shone in the sun. One day, Tweetie was flying in the sky when she saw a little girl named Lily who was lost.
Lily was crying and didn't know where to go. Tweetie felt compassionate and flew down to her. "Don't worry, little girl. I'll help you find your way home," said Tweetie.
Lily was so happy to have a friend like Tweetie. She told Tweetie a joke and they laughed together. "Why did the tomato turn red?" asked Lily. "I don't know, why?" replied Tweetie. "Because it saw the salad dressing!" said Lily, laughing again.
Tweetie was so happy to help Lily find her way home. She flew away feeling proud of herself. But then, something unexpected happened. A big gust of wind blew Tweetie away from the girl and she landed on a branch far away from her. Tweetie was scared and didn't know what to do. She
achieved tok/s: 56.679262

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories42M.bin
Once upon a time, there was a big bird named Ostrich. Ostrich was very tall and had long legs. Ostrich liked to run and jump in the big blue sky.
One day, Ostrich met a little girl named Lily. Lily said, "Hi Ostrich, can you count how many times I can run in the sky?" Ostrich said, "Sure, I can count to ten!"
Lily ran and ran and ran, but she tripped and fell. Ostrich ran over to her and said, "Are you okay?" Lily said, "Yes, I'm okay. I'm okay, but I need to be careful."
Suddenly, a big storm came and the sky turned dark. Ostrich tried to run away, but his big legs were not strong enough to withstand the storm. The storm was too strong and Ostrich couldn't stay safe. In the end, Ostrich got hurt and couldn't run anymore. Lily was very sad and wished she could have done something to help Ostrich.
achieved tok/s: 55.766493

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories42M.bin
Once upon a time, there was a little girl named Lily. She had a furry teddy bear that she loved very much. One day, Lily was playing with her teddy bear and she accidentally split a glass of juice on the floor. She tried to clean it up with a napkin, but it made a bigger mess.
Lily's mom saw what happened and said, "Lily, you need to be careful with the table. You could hurt yourself or break something."
Lily felt sad because she didn't want to make a mess. She asked her mom, "What can I do to clean it up?"
Her mom said, "You can use a plastic cloth to wipe it up. That way, you won't get hurt and the mess will go away."
Lily learned that it's important to be careful and clean up after herself. She also learned that accidents happen and it's okay to ask for help. From that day on, she made sure to be extra careful and never split anything again.
achieved tok/s: 56.637168

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories42M.bin
Once upon a time, there was a small boy named Timmy. Timmy had a pet turtle named Timmy. Timmy loved Timmy very much and always made sure to feed him food and water.
One day, Timmy's mom bought him a microscope. Timmy was very curious about it and wanted to look at everything through it. He showed it to his turtle and asked, "Do you want to look at tiny things with my microscope?"
But Timmy was a slow turtle and didn't want to wait for him to get there. He went to his mom and said, "I want to eat some leaves from the tree outside." His mom gave him some leaves and he looked at them with the microscope. He was very happy to see tiny bugs and plants up close.
From that day on, Timmy loved to look at things with his microscope and feed them with his food. And he always remembered to be slow and patient with his turtle. The end.
achieved tok/s: 55.977710

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories110M.bin
Once upon a time, there was a little girl named Lily. She loved playing with her toys, especially her robot. The robot was very lively and could move and talk just like a real person.
One day, Lily's friend came over to play. Her friend wanted to play with the robot, but Lily didn't want to share. "No, you can't play with my robot!" Lily said. Her friend felt sad and left.
Later that day, Lily realized that she was being mean to her friend. She decided to go find her and say sorry. When she found her friend, she said, "I'm sorry for not sharing my robot. Will you forgive me?" Her friend smiled and said, "Yes, I forgive you."
Lily learned that sharing is important and that saying sorry can make things better. From that day on, she always shared her toys with her friends.
achieved tok/s: 21.109579

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories110M.bin
Once upon a time, there was a little girl named Lily. She loved to play outside in the big, green field behind her house. One day, she found a huge pile of oats in the field. She wanted to use them to make oatmeal for breakfast.
Lily went back to her house to get some oats. She was so excited to use them to make her breakfast. But when she came back to the field, she saw that someone had already used all of the oats to make a big pile. Lily was very sad because she really wanted to make oatmeal.
Suddenly, Lily heard a rustling in the bushes. She went to investigate and found a little bunny eating all of the oats! Lily was surprised but happy to see the bunny. She decided to share her oats with the bunny and they became friends. From then on, Lily and the bunny would play together in the big, green field and make oatmeal for breakfast.
achieved tok/s: 21.013133

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories110M.bin
Once upon a time, there was a playful dog named Spot. Spot loved to play with his friends. One day, Spot saw a big jar of sugar. He wanted to eat it, but he knew he should ask first.
Spot went to his friend, Cat. He asked, "Can I have some sugar?" Cat said, "No, Spot. Sugar is not good for dogs." Spot was sad, but he did not refuse to listen to Cat.
Spot went to his friend, Bird. He asked, "Can I have some sugar?" Bird said, "No, Spot. Sugar is not good for dogs." Spot was sad, but he did not refuse to listen to Bird. He went to his friend, Fish. Spot asked, "Can I have some sugar?" Fish said, "No, Spot. Sugar is not good for fish." Spot was sad, but he listened to Fish.
One day, Spot saw a big bag of sugar. He thought, "I want to eat some sugar!" But he remembered what his friends said. Spot decided to play with his friends instead. They all had a fun day together. Spot was happy he
achieved tok/s: 21.062195

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories110M.bin
Once upon a time, there was a little boy named Timmy. Timmy loved to eat cookies. One day, his mom made him a yummy cookie. Timmy was so happy and took a big bite. Suddenly, he saw a worm in his cookie!
"Ew, a worm!" Timmy yelled.
"Don't worry, Timmy. Just cut the worm out," said his mom.
Timmy carefully cut out the worm and threw the yummy cookie away. From then on, he always checked his cookies before taking a bite.
achieved tok/s: 21.192053

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./run stories110M.bin
Once upon a time, there was a little boy named Tim. Tim liked to play with his toys and make a big mess. One day, Tim's mom said, "Tim, it's time to clean up your messy room."
Tim looked at his mom and said, "Okay, Mom, I will clean up my room." Tim started to pick up his toys and put them away. While he was cleaning, he found a big, soft teddy bear. Tim gave the teddy bear a big squeeze and said, "I love you, teddy bear."
Tim's mom came back into the room and saw that Tim had cleaned up his messy room. She was very happy and said, "Good job, Tim! Now you can play with your toys again." Tim smiled and hugged his teddy bear. He knew that a clean room was a happy place to be.
achieved tok/s: 21.130537
On an Intel® Core™ i9-12900K, output rate in tokens/s (mean of the five runs above):

Version     stories15M    stories42M    stories110M
Baseline    168.6677130   56.4654472    21.1014994

Optimizing for the CPU

To enable OpenMP, compile with -fopenmp:

gcc -O3 -o orun run.c -lm -fopenmp
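-fopenmp only has an effect where the source contains OpenMP directives. The sketch below assumes a row-parallel matmul with the same shape as run.c's (W (d, n) @ x (n,) -> xout (d,)); matmul_omp is a hypothetical name. Each output row is independent, so #pragma omp parallel for can hand different rows to different threads without any locking. Without -fopenmp the pragma is ignored and the loop runs serially, giving the same result.

```c
#include <assert.h>

// Hypothetical row-parallel matmul: W (d, n) @ x (n,) -> xout (d,).
// Iterations of the outer loop touch disjoint outputs, so they can run
// on different threads; the inner dot product stays sequential.
void matmul_omp(float* xout, const float* x, const float* w, int n, int d) {
    assert(n >= 0 && d >= 0);
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;            // declared in the loop body, so private to each thread
        for (int j = 0; j < n; j++)
            val += w[i * n + j] * x[j];
        xout[i] = val;               // each iteration writes a different element
    }
}
```

The number of threads can be set at run time with the standard OMP_NUM_THREADS environment variable, for example OMP_NUM_THREADS=8 ./orun stories110M.bin.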

To use AVX, rewrite matmul with AVX/FMA intrinsics:

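#include <immintrin.h>  // AVX/FMA intrinsics (place with the other includes in run.c)

void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; ++i) {
        const float* row = &w[i * n];
        __m256 sum_vec = _mm256_setzero_ps();
        int j = 0;
        // 8 floats per iteration: sum_vec += row[j..j+7] * x[j..j+7] (fused multiply-add)
        for (; j + 8 <= n; j += 8) {
            __m256 w_vec = _mm256_loadu_ps(&row[j]);
            __m256 x_vec = _mm256_loadu_ps(&x[j]);
            sum_vec = _mm256_fmadd_ps(w_vec, x_vec, sum_vec);
        }
        // Horizontal sum of the 8 lanes
        float temp[8];
        _mm256_storeu_ps(temp, sum_vec);
        float val = temp[0] + temp[1] + temp[2] + temp[3] + temp[4] + temp[5] + temp[6] + temp[7];
        // Scalar tail in case n is not a multiple of 8
        for (; j < n; ++j)
            val += row[j] * x[j];
        xout[i] = val;
    }
}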
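The AVX version keeps eight independent partial sums (one per lane of sum_vec) and adds them only at the end, so it sums in a different order than the scalar loop and its results can differ from the baseline in the last bits; it also needs a scalar tail whenever n is not a multiple of 8. The portable C sketch below emulates that accumulation order without intrinsics, which makes it easy to test on any machine; dot_8lanes is a hypothetical helper, not part of run.c.

```c
#include <assert.h>

// Portable emulation of the lane-wise accumulation in the AVX matmul:
// lane[l] accumulates the terms with j % 8 == l, like the 8 lanes of sum_vec,
// then the lanes are added left to right (the temp[0] + ... + temp[7] step).
// Elements past the last full group of 8 go through a scalar tail loop.
float dot_8lanes(const float* a, const float* b, int n) {
    assert(n >= 0);
    float lane[8] = {0};
    int j = 0;
    for (; j + 8 <= n; j += 8)
        for (int l = 0; l < 8; ++l)
            // _mm256_fmadd_ps rounds once per multiply-add; here the multiply
            // and add may be rounded separately, so the last bits can differ.
            lane[l] += a[j + l] * b[j + l];
    float sum = 0.0f;
    for (int l = 0; l < 8; ++l)
        sum += lane[l];                 // horizontal reduction
    for (; j < n; ++j)
        sum += a[j] * b[j];             // scalar tail
    return sum;
}
```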

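Compile as follows, first without and then with OpenMP. Note that _mm256_fmadd_ps is an FMA instruction: -mavx alone does not enable FMA, but -march=native does on CPUs that support it, such as the i9-12900K.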

gcc -O3 -o arun run.c -mavx -march=native -lm
gcc -O3 -o aorun run.c -mavx -march=native -lm -fopenmp
┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./arun stories15M.bin
Once upon a time, there was a little girl named Lily. She loved to play outside and pick flowers. One day, she saw a cop car drive by and it made her feel excited.
"Mommy, what is that car doing?" Lily asked.
"It's a cop, sweetie. He helps keep us safe," her mom replied.
As they continued their walk, they saw a sign that said "Beware of Coco". Lily didn't understand what that meant, but she thought it sounded like fun.
Later that day, Lily's mom took her to the park. They saw a big sandcastle that was as tall as Lily! She was so excited that she wanted to stay close to her. Her mom reminded her that it's important to always stay close to the ones you trust and to never wander off alone. Lily understood and promised to always stay close to her mom.
achieved tok/s: 318.611987

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./arun stories42M.bin
One day, a boy named Tom went to the park. He saw a big tree with dry leaves. Tom wanted to cut the dry leaves and make a fun toy. He took a small saw from his dad's tool box. Tom started to cut the dry leaves.
A big bird came and sat on the tree. The bird said, "Hello, Tom! What are you doing?" Tom said, "I am cutting the dry leaves to make a toy." The bird was happy and flew away. Tom kept cutting the dry leaves.
Suddenly, the wind blew very hard. The dry leaves flew in all directions. They hit Tom's dad on the head. His dad said, "Ouch!" Tom was sad. He didn't mean to make his dad fall. Now, the fun toy was gone and Tom was hurt.
achieved tok/s: 120.836055

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./arun stories110M.bin
Once upon a time, there was a little girl named Lily. One day, she went to the park with her mom. She saw a big, red balloon and wanted to touch it. Her mom said it was too high, but Lily didn't listen. She climbed up and touched it. It was so fun!
Suddenly, a bird flew by and scared Lily. She lost her balance and fell down. Her mom rushed to her side and asked if she was okay. Lily was a little scared, but her mom hugged her and said she was okay.
After that, Lily and her mom went to get some ice cream. Lily chose a mild flavor that she had never tried before. She said to her mom, "This ice cream is yummy!" Her mom smiled and said, "I'm glad you like it, Lily."
achieved tok/s: 45.656755
Output rate (tokens/s) on an Intel® Core™ i9-12900K

| Build | stories15M | stories42M | stories110M |
|---|---|---|---|
| Baseline | 168.6677130 | 56.4654472 | 21.1014994 |
| With OpenMP | 432.926829 | 177.809388 | 57.982319 |
| With AVX | 318.611987 | 120.836055 | 45.656755 |
| With AVX and OpenMP | 415.986949 | 145.546705 | 54.567380 |
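AVX alone roughly doubles throughput (+89% to +116% depending on the model).
OpenMP alone improves throughput by roughly 170% (+157% to +215%).
Combining AVX and OpenMP helps less than expected. It gives about +150% (+147% to +159%), which is below OpenMP alone on every model.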

GPU Optimization

run.cu
/* Inference for Llama-2 Transformer model in pure C */

#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <fcntl.h>
//#include <immintrin.h>
#if defined _WIN32
    #include "win.h"
#else
    #include <unistd.h>
    #include <sys/mman.h>
#endif
//REVIEW L-I CUDA Headers
#include <iostream>     // std::cerr in the error-check helpers below
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <cublas_v2.h>
//REVIEW END

//REVIEW DEFINE CUDACHECKs
#define cudaCheck(err) __cudaCheck(err, __FILE__, __LINE__)
#define cublasCheck(err) __cublasCheck(err, __FILE__, __LINE__)
inline void __cudaCheck(cudaError_t err, const char* file, int line) {
    if (err != cudaSuccess) {
        std::cerr << "CUDA Error: " << cudaGetErrorString(err) 
                  << " (" << err << ") at " << file << ":" << line << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

inline void __cublasCheck(cublasStatus_t err, const char* file, int line) {
    if (err != CUBLAS_STATUS_SUCCESS) {
        std::cerr << "cuBLAS Error: " << err 
                  << " at " << file << ":" << line << std::endl;
        std::exit(EXIT_FAILURE);
    }
}
//REVIEW END
// ----------------------------------------------------------------------------
// Transformer model

typedef struct {
    int dim; // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers; // number of layers
    int n_heads; // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len; // max sequence length
} Config;

typedef struct {
    // token embedding table
    float* token_embedding_table;    // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    float* wq; // (layer, dim, n_heads * head_size)
    float* wk; // (layer, dim, n_kv_heads * head_size)
    float* wv; // (layer, dim, n_kv_heads * head_size)
    float* wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    float* w1; // (layer, hidden_dim, dim)
    float* w2; // (layer, dim, hidden_dim)
    float* w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // (optional) classifier weights for the logits, on the last layer
    float* wcls;
} TransformerWeights;

typedef struct {
    // current wave of activations
    float *x; // activation at current time stamp (dim,)
    float *xb; // same, but inside a residual branch (dim,)
    float *xb2; // an additional buffer just for convenience (dim,)
    float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q; // query (dim,)
    float *k; // key (dim,)
    float *v; // value (dim,)
    float *att; // buffer for scores/attention values (n_heads, seq_len)
//REVIEW logit on GPU
    float *logitsgpu; //output logits GPU
//REVIEW END
    float *logits; // output logits
    // kv cache
    float* key_cache;   // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;

typedef struct {
    Config config; // the hyperparameters of the architecture (the blueprint)
    TransformerWeights weights; // the weights of the model
    RunState state; // buffers for the "wave" of activations in the forward pass
    // some more state needed to properly clean up the memory mapping (sigh)
    int fd; // file descriptor for memory mapping
    float* data; // memory mapped data pointer
    ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;

//REVIEW cuBLAS Handle
cublasHandle_t cuBLASHandle = 0;

inline void createCublasHandle() {
    cublasCheck(cublasCreate(&cuBLASHandle));
}

inline void destroyCublasHandle() {
    cublasCheck(cublasDestroy(cuBLASHandle));
}
//REVIEW END
void malloc_run_state(RunState* s, Config* p) {
    // device buffers use cudaMalloc (cudaCheck aborts on failure); only the host-side logits use calloc
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
//REVIEW malloc in GPU
    cudaCheck(cudaMalloc((void**)&s->x, p->dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->xb, p->dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->xb2, p->dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->hb, p->hidden_dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->hb2, p->hidden_dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->q, p->dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->key_cache, p->n_layers * p->seq_len * kv_dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->value_cache, p->n_layers * p->seq_len * kv_dim * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->att, p->n_heads * p->seq_len * sizeof(float)));
    cudaCheck(cudaMalloc((void**)&s->logitsgpu, p->vocab_size * sizeof(float)));
//REVIEW END
//FIXME cast to float*
    s->logits = (float*)calloc(p->vocab_size, sizeof(float));
    // ensure all mallocs went fine
//FIXME add logitsgpu check
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
     || !s->key_cache || !s->value_cache || !s->att || !s->logits || !s->logitsgpu) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}

void free_run_state(RunState* s) {
//REVIEW free in GPU
    cudaCheck(cudaFree(s->x));
    cudaCheck(cudaFree(s->xb));
    cudaCheck(cudaFree(s->xb2));
    cudaCheck(cudaFree(s->hb));
    cudaCheck(cudaFree(s->hb2));
    cudaCheck(cudaFree(s->q));
    cudaCheck(cudaFree(s->att));
    cudaCheck(cudaFree(s->logitsgpu));
    cudaCheck(cudaFree(s->key_cache));
    cudaCheck(cudaFree(s->value_cache));
//REVIEW END
    free(s->logits);
}

void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
    int head_size = p->dim / p->n_heads;
    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
    unsigned long long n_layers = p->n_layers;
    w->token_embedding_table = ptr;
    ptr += p->vocab_size * p->dim;
    w->rms_att_weight = ptr;
    ptr += n_layers * p->dim;
    w->wq = ptr;
    ptr += n_layers * p->dim * (p->n_heads * head_size);
    w->wk = ptr;
    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
    w->wv = ptr;
    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
    w->wo = ptr;
    ptr += n_layers * (p->n_heads * head_size) * p->dim;
    w->rms_ffn_weight = ptr;
    ptr += n_layers * p->dim;
    w->w1 = ptr;
    ptr += n_layers * p->dim * p->hidden_dim;
    w->w2 = ptr;
    ptr += n_layers * p->hidden_dim * p->dim;
    w->w3 = ptr;
    ptr += n_layers * p->dim * p->hidden_dim;
    w->rms_final_weight = ptr;
    ptr += p->dim;
    ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
    ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
    w->wcls = shared_weights ? w->token_embedding_table : ptr;
}

void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     int* fd, float** data, ssize_t* file_size) {
    FILE *file = fopen(checkpoint, "rb");
    if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
    // read in the config header
    if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
    int shared_weights = config->vocab_size > 0 ? 1 : 0;
    config->vocab_size = abs(config->vocab_size);
    // figure out the file size
    fseek(file, 0, SEEK_END); // move file pointer to end of file
    *file_size = ftell(file); // get the file size, in bytes
    fclose(file);
    // memory map the Transformer weights into the data pointer
    *fd = open(checkpoint, O_RDONLY); // open in read only mode
    if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
    *data = (float *)mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
//REVIEW Weight copy
    float* weights_ptr = 0;
    // FIXME remove config in the checkpoint file
    cudaCheck(cudaMalloc((void**)&weights_ptr, *file_size - sizeof(Config)));
    cudaCheck(cudaMemcpy(weights_ptr, *data + sizeof(Config)/sizeof(float), *file_size - sizeof(Config), cudaMemcpyHostToDevice));
//REVIEW END
    memory_map_weights(weights, config, weights_ptr, shared_weights);
}

void build_transformer(Transformer *t, char* checkpoint_path) {
    // read in the Config and the Weights from the checkpoint
    read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
    // allocate the RunState buffers
    malloc_run_state(&t->state, &t->config);
}

void free_transformer(Transformer* t) {
    // close the memory mapping
    if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
    if (t->fd != -1) { close(t->fd); }
    // free the RunState buffers
//REVIEW remove the weight on GPU
    cudaCheck(cudaFree(t->weights.token_embedding_table));
//REVIEW END
    free_run_state(&t->state);
}

// ----------------------------------------------------------------------------
// neural net blocks; the dynamics of the Transformer
//REVIEW Spread Tensor
int divUp(int a, int b) {
    return (a + b - 1) / b;
}
//REVIEW END
//REVIEW RMSNorm on GPU
// x and out may alias (the final rmsnorm in forward() runs in place), so they are not __restrict__
__global__ void rmsnorm_cuk(const float* x, const float* __restrict__ weight, float* out, int dim, float eps) {
    extern __shared__ float sdata[]; 
    int tid = threadIdx.x;
    float val = 0.0f;
    if (tid < dim) {
        float xv = x[tid];
        val = xv * xv;
    }
    typedef cub::BlockReduce<float, 1024> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sum = BlockReduce(temp_storage).Sum(val);
    __syncthreads();

    if (tid == 0) {
        sum = sum / dim + eps;
        sum = rsqrtf(sum);
        sdata[0] = sum;
    }
    __syncthreads();

    float norm_coef = sdata[0];
    if (tid < dim) {
        out[tid] = weight[tid] * (x[tid] * norm_coef);
    }
}

void rmsnorm(float* out, const float* x, const float* weight, int dim, cudaStream_t stream=0) {
    int blockSize = 1024;  // a single block: assumes dim <= 1024 (288/512/768 for stories15M/42M/110M)
    int gridSize = 1;
    size_t smem = sizeof(float);
    rmsnorm_cuk<<<gridSize, blockSize, smem, stream>>>(x, weight, out, dim, 1e-5f);
}
//REVIEW END
//REVIEW softmax on GPU
__device__ void softmax_gpu(float* __restrict__ x, int size) {
    int tid = threadIdx.x;
    int step = blockDim.x;

    // find max value (for numerical stability)
    // threads that own no element start at -INFINITY so they cannot raise the max
    float max_val = -INFINITY;
    for (int i = tid; i < size; i += step) {
        max_val = fmaxf(max_val, x[i]);
    }
    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ float shared_val;
    max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max());
    if (threadIdx.x == 0) {
        shared_val = max_val;
    }
    __syncthreads();
    max_val = shared_val;

    // exp and sum
    float sum = 0.0f;
    for (int i = tid; i < size; i += step) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    sum = BlockReduce(temp_storage).Sum(sum);
    if (threadIdx.x == 0) {
        shared_val = sum;
    }
    __syncthreads();
    sum = shared_val;

    // normalize
    for (int i = tid; i < size; i += step) {
        x[i] /= sum;
    }
}

void softmax(float* x, int size) {
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}
//REVIEW END
//REVIEW matmul on GPU
void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    float alpha = 1.0f;
    float beta = 0.0f;
    cublasSgemv(cuBLASHandle, CUBLAS_OP_T, n, d, &alpha, w, n, x, 1, &beta, xout, 1);
}
//REVIEW END
//REVIEW RoPE on GPU
__global__ void RoPE_cuk(int pos, float *sq, float *sk, int dim, int kv_dim, int head_size) {
    int global_id = blockIdx.x * blockDim.x + threadIdx.x;
    int idx = global_id * 2;

    extern __shared__ float freq_table[];

    // build the per-block frequency table before any thread exits:
    // every thread of the block must reach the __syncthreads() below
    if (threadIdx.x < head_size) {
        float inv_head = 1.0f / (float)head_size;
        float neg_ln_10000 = -__logf(10000.0f);
        // freq_table[head_dim] = exp(-ln(10000)*(head_dim/head_size))
        freq_table[threadIdx.x] = __expf(neg_ln_10000 * threadIdx.x * inv_head);
    }

    __syncthreads();

    if (idx >= dim) return;

    int head_dim = idx % head_size;
    float freq = freq_table[head_dim];
    float val = pos * freq;
    float fcr = __cosf(val);
    float fci = __sinf(val);

    int rotn = (idx < kv_dim) ? 2 : 1;
    for (int v = 0; v < rotn; v++) {
        float* vec = (v == 0) ? sq : sk;
        float v0 = vec[idx];
        float v1 = vec[idx+1];
        float rv0 = v0 * fcr - v1 * fci;
        float rv1 = v0 * fci + v1 * fcr;
        vec[idx]   = rv0;
        vec[idx+1] = rv1;
    }
}

void RoPE(int pos, RunState* s, int dim, int kv_dim, int head_size) {
    int threadsPerBlock = 256;
    int halfDim = dim / 2;
    int blocks = (halfDim + threadsPerBlock - 1) / threadsPerBlock;
    size_t shared_mem_bytes = head_size * sizeof(float);
    RoPE_cuk<<<blocks, threadsPerBlock, shared_mem_bytes>>>(pos, s->q, s->k, dim, kv_dim, head_size);
    cudaCheck(cudaGetLastError());
    cudaCheck(cudaDeviceSynchronize());
}
//REVIEW END
//REVIEW MHA on GPU
__global__ void multi_head_attention_cuk(
    int pos, int seq_len, float *sq, float *satt, float *sxb, 
    float *key_cache, float *value_cache, 
    int kv_dim, int kv_mul, int head_size, int loff, float inv_sqrt_head_size) 
{
    int h = blockIdx.x;
    int tid = threadIdx.x;

    extern __shared__ float shared[];
    float* q_shared = shared;
    float* att_shared = q_shared + head_size;

    float* q = sq + h * head_size;
    for (int i = tid; i < head_size; i += blockDim.x) {
        q_shared[i] = q[i];
    }

    __syncthreads();

    float* att = satt + h * seq_len;

    for (int t = tid; t <= pos; t += blockDim.x) {
        float* k = key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
        float score = 0.0f;
        for (int i = 0; i < head_size; i++) {
            score += q_shared[i] * k[i];
        }
        score *= inv_sqrt_head_size;
        att[t] = score;
    }
    __syncthreads();

    softmax_gpu(att, pos + 1);
    __syncthreads();

    for (int t = tid; t <= pos; t += blockDim.x) {
        att_shared[t] = att[t];
    }

    __syncthreads();

    float* xb = sxb + h * head_size;

    for (int i = tid; i < head_size; i += blockDim.x) {
        float val = 0.0f;
        for (int t = 0; t <= pos; t++) {
            float* v = value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
            val += att_shared[t] * v[i];
        }
        xb[i] = val;
    }
}

void multi_head_attention(int pos, Config* p, RunState* s, int kv_dim, int kv_mul, int head_size, int loff)
{
    float inv_sqrt_head_size = 1.0f / sqrtf((float)head_size);
    int grid = p->n_heads;
    int block = 1024; 
    size_t shared_mem_size = head_size * sizeof(float) + p->seq_len * sizeof(float);

    multi_head_attention_cuk <<<grid, block, shared_mem_size>>> (
        pos, p->seq_len, s->q, s->att, s->xb, s->key_cache, s->value_cache, 
        kv_dim, kv_mul, head_size, loff, inv_sqrt_head_size);
}
//REVIEW END
//REVIEW SiLU on GPU
__global__ void SwiGLU_cuk(float *hb, float *hb2, int hidden_dim) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < hidden_dim) {
        float val = hb[i];
        // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
        val *= (1.0f / (1.0f + expf(-val)));
        // elementwise multiply with w3(x)
        val *= hb2[i];
        hb[i] = val;
    }
}

void SwiGLU(RunState *s, int hidden_dim) {
    SwiGLU_cuk<<<divUp(hidden_dim, 1024), 1024>>>(s->hb, s->hb2, hidden_dim);
}
//REVIEW END
//REVIEW R on GPU
__global__ void residual_connection_cuk(float* x, float* xb, int dim) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < dim) {
        x[i] += xb[i];
    }
}
void residual_connection(float *x, float *xb, int dim) {
    residual_connection_cuk<<<divUp(dim, 1024), 1024>>>(x, xb, dim);
}
//REVIEW END
//REVIEW forward on GPU
float* forward(Transformer* transformer, int token, int pos) {

    // a few convenience variables
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim =  p->hidden_dim;
    int head_size = dim / p->n_heads;

    // copy the token embedding into x (the embedding table already lives on the GPU)
    float* content_row = w->token_embedding_table + token * dim;
    cudaCheck(cudaMemcpy(x, content_row, dim*sizeof(*x), cudaMemcpyDeviceToDevice));

    // forward all the layers
    for(unsigned long long l = 0; l < p->n_layers; l++) {

        // attention rmsnorm
        rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

        // key and value point to the kv cache
        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
        s->k = s->key_cache + loff + pos * kv_dim;
        s->v = s->value_cache + loff + pos * kv_dim;

        // qkv matmuls for this position
        matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
        matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
        matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);

        // RoPE relative positional encoding: complex-valued rotate q and k in each head
        RoPE(pos, s, dim, kv_dim, head_size);

        // multihead attention. iterate over all heads
        multi_head_attention(pos, p, s, kv_dim, kv_mul, head_size, loff);

        // final matmul to get the output of the attention
        matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);

        // residual connection back into x
        residual_connection(x, s->xb2, dim);

        // ffn rmsnorm
        rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
        matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);

        // SwiGLU non-linearity
        SwiGLU(s, hidden_dim);

        // final matmul to get the output of the ffn
        matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);

        // residual connection
        residual_connection(x, s->xb, dim);
    }

    // final rmsnorm
    rmsnorm(x, x, w->rms_final_weight, dim);

    // classifier into logits
    matmul(s->logitsgpu, x, w->wcls, p->dim, p->vocab_size);
    cudaCheck(cudaMemcpy(s->logits, s->logitsgpu, p->vocab_size * sizeof(float), cudaMemcpyDeviceToHost));
    return s->logits;
}
//REVIEW END
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
    char *str;
    int id;
} TokenIndex;

typedef struct {
    char** vocab;
    float* vocab_scores;
    TokenIndex *sorted_vocab;
    int vocab_size;
    unsigned int max_token_length;
    unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;

int compare_tokens(const void *a, const void *b) {
    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
    // i should have written the vocab_size into the tokenizer file... sigh
    t->vocab_size = vocab_size;
    // malloc space to hold the scores and the strings
    t->vocab = (char**)malloc(vocab_size * sizeof(char*));
    t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
    t->sorted_vocab = NULL; // initialized lazily
    for (int i = 0; i < 256; i++) {
        t->byte_pieces[i * 2] = (unsigned char)i;
        t->byte_pieces[i * 2 + 1] = '\0';
    }
    // read in the file
    FILE *file = fopen(tokenizer_path, "rb");
    if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
    if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
    int len;
    for (int i = 0; i < vocab_size; i++) {
        if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
        if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i] = (char *)malloc(len + 1);
        if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i][len] = '\0'; // add the string terminating token
    }
    fclose(file);
}

void free_tokenizer(Tokenizer* t) {
    for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
    free(t->vocab);
    free(t->vocab_scores);
    free(t->sorted_vocab);
}

char* decode(Tokenizer* t, int prev_token, int token) {
    char *piece = t->vocab[token];
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ') { piece++; }
    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    // parse this and convert and return the actual byte
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
        piece = (char*)t->byte_pieces + byte_val * 2;
    }
    return piece;
}

void safe_printf(char *piece) {
    // piece might be a raw byte token, and we only want to print printable chars or whitespace
    // because some of the other bytes can be various control codes, backspace, etc.
    if (piece == NULL) { return; }
    if (piece[0] == '\0') { return; }
    if (piece[1] == '\0') {
        unsigned char byte_val = piece[0];
        if (!(isprint(byte_val) || isspace(byte_val))) {
            return; // bad byte, don't print it
        }
    }
    printf("%s", piece);
}

int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    // efficiently find the perfect match for str in vocab, return its index or -1 if not found
    TokenIndex tok = { .str = str }; // acts as the key to search for
//FIXME cast to TokenIndex*
    TokenIndex *res = (TokenIndex *)bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
    return res != NULL ? res->id : -1;
}

void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
    // encode the string text (input) into an upper-bound preallocated tokens[] array
    // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }

    if (t->sorted_vocab == NULL) {
        // lazily malloc and sort the vocabulary
        t->sorted_vocab = (TokenIndex *)malloc(t->vocab_size * sizeof(TokenIndex));
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }

    // create a temporary buffer that will store merge candidates of always two consecutive tokens
    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
    char* str_buffer = (char *)malloc((t->max_token_length*2 +1 +2) * sizeof(char));
    size_t str_len = 0;

    // start at 0 tokens
    *n_tokens = 0;

    // add optional BOS (=1) token, if desired
    if (bos) tokens[(*n_tokens)++] = 1;

    // add_dummy_prefix is true by default
    // so prepend a dummy prefix token to the input string, but only if text != ""
    // TODO: pretty sure this isn't correct in the general case but I don't have the
    // energy to read more of the sentencepiece code to figure out what it's doing
    if (text[0] != '\0') {
        int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
        tokens[(*n_tokens)++] = dummy_prefix;
    }

    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
    // Code point ↔ UTF-8 conversion
    // First code point	Last code point	Byte 1	Byte 2	Byte 3	Byte 4
    // U+0000	U+007F	    0xxxxxxx
    // U+0080	U+07FF	    110xxxxx	10xxxxxx
    // U+0800	U+FFFF	    1110xxxx	10xxxxxx	10xxxxxx
    // U+10000	U+10FFFF    11110xxx	10xxxxxx	10xxxxxx	10xxxxxx

    // process the raw (UTF-8) byte sequence of the input string
    for (char *c = text; *c != '\0'; c++) {

        // reset buffer if the current byte is ASCII or a leading byte
        // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
        // 0x80 is 10000000
        // in UTF-8, all continuation bytes start with "10" in first two bits
        // so in English this is: "if this byte is not a continuation byte"
        if ((*c & 0xC0) != 0x80) {
            // this byte must be either a leading byte (11...) or an ASCII char (0x...)
            // => reset our location, as we're starting a new UTF-8 codepoint
            str_len = 0;
        }

        // append the current byte to the buffer
        str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
        str_buffer[str_len] = '\0';

        // while the next character is a continuation byte, continue appending
        // but if there are too many of them, just stop to avoid overruning str_buffer size.
        if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
            continue;
        }

        // ok c+1 is not a continuation byte, so we've read in a full codepoint
        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);

        if (id != -1) {
            // we found this codepoint in vocab, add it as a token
            tokens[(*n_tokens)++] = id;
        } else {
            // byte_fallback encoding: just encode each byte as a token
            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
            // so the individual bytes only start at index 3
            for (int i=0; i < str_len; i++) {
                tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
            }
        }
        str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
    }

    // merge the best consecutive pair each iteration, according the scores in vocab_scores
    while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;

        for (int i=0; i < (*n_tokens-1); i++) {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
            if (id != -1 && t->vocab_scores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

        if (best_idx == -1) {
            break; // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id;
        // delete token at position best_idx+1, shift the entire sequence back 1
        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
            tokens[i] = tokens[i+1];
        }
        (*n_tokens)--; // token length decreased
    }

    // add optional EOS (=2) token, if desired
    if (eos) tokens[(*n_tokens)++] = 2;

    free(str_buffer);
}

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

typedef struct {
    float prob;
    int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

typedef struct {
    int vocab_size;
    ProbIndex* probindex; // buffer used in top-p sampling
    float temperature;
    float topp;
    unsigned long long rng_state;
} Sampler;

int sample_argmax(float* probabilities, int n) {
    // return the index that has the highest probability
    int max_i = 0;
    float max_p = probabilities[0];
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > max_p) {
            max_i = i;
            max_p = probabilities[i];
        }
    }
    return max_i;
}

int sample_mult(float* probabilities, int n, float coin) {
    // sample index from probabilities (they must sum to 1!)
    // coin is a random number in [0, 1), usually from random_f32()
    float cdf = 0.0f;
    for (int i = 0; i < n; i++) {
        cdf += probabilities[i];
        if (coin < cdf) {
            return i;
        }
    }
    return n - 1; // in case of rounding errors
}

int compare(const void* a, const void* b) {
    ProbIndex* a_ = (ProbIndex*) a;
    ProbIndex* b_ = (ProbIndex*) b;
    if (a_->prob > b_->prob) return -1;
    if (a_->prob < b_->prob) return 1;
    return 0;
}

int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens that exceed probability topp. This way we never sample tokens that
    // have very low probabilities and are less likely to go "off the rails".
    // coin is a random number in [0, 1), usually from random_f32()

    int n0 = 0;
    // quicksort indices in descending order of probabilities
    // values smaller than (1 - topp) / (n - 1) cannot be part of the result
    // so for efficiency we crop these out as candidates before sorting
    const float cutoff = (1.0f - topp) / (n - 1);
    for (int i = 0; i < n; i++) {
        if (probabilities[i] >= cutoff) {
            probindex[n0].index = i;
            probindex[n0].prob = probabilities[i];
            n0++;
        }
    }
    qsort(probindex, n0, sizeof(ProbIndex), compare);

    // truncate the list where cumulative probability exceeds topp
    float cumulative_prob = 0.0f;
    int last_idx = n0 - 1; // in case of rounding errors consider all elements
    for (int i = 0; i < n0; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break; // we've exceeded topp by including last_idx
        }
    }

    // sample from the truncated list
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index; // in case of rounding errors
}

void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
    sampler->vocab_size = vocab_size;
    sampler->temperature = temperature;
    sampler->topp = topp;
    sampler->rng_state = rng_seed;
    // buffer only used with nucleus sampling; may not need but it's ~small
    sampler->probindex = (ProbIndex *)malloc(sampler->vocab_size * sizeof(ProbIndex));
}

void free_sampler(Sampler* sampler) {
    free(sampler->probindex);
}

unsigned int random_u32(unsigned long long *state) {
    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
    *state ^= *state >> 12;
    *state ^= *state << 25;
    *state ^= *state >> 27;
    return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32(unsigned long long *state) { // random float32 in [0,1)
    return (random_u32(state) >> 8) / 16777216.0f;
}

int sample(Sampler* sampler, float* logits) {
    // sample the token given the logits and some hyperparameters
    int next;
    if (sampler->temperature == 0.0f) {
        // greedy argmax sampling: take the token with the highest probability
        next = sample_argmax(logits, sampler->vocab_size);
    } else {
        // apply the temperature to the logits
        for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
        // apply softmax to the logits to get the probabilities for next token
        softmax(logits, sampler->vocab_size);
        // flip a (float) coin (this is our source of entropy for sampling)
        float coin = random_f32(&sampler->rng_state);
        // we sample from this distribution to get the next token
        if (sampler->topp <= 0 || sampler->topp >= 1) {
            // simply sample from the predicted probability distribution
            next = sample_mult(logits, sampler->vocab_size, coin);
        } else {
            // top-p (nucleus) sampling, clamping the least likely tokens to zero
            next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
        }
    }
    return next;
}

// ----------------------------------------------------------------------------
// utilities: time

long time_in_ms() {
    // return time in milliseconds, for benchmarking the model speed
    struct timespec time;
    clock_gettime(CLOCK_REALTIME, &time);
    return time.tv_sec * 1000 + time.tv_nsec / 1000000;
}

// ----------------------------------------------------------------------------
// generation loop

void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
    char *empty_prompt = "";
    if (prompt == NULL) { prompt = empty_prompt; }

    // encode the (string) prompt into tokens sequence
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
    if (num_prompt_tokens < 1) {
        fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
        exit(EXIT_FAILURE);
    }

    // start the main loop
    long start = 0;  // used to time our code, only initialized after first iteration
    int next;        // will store the next token in the sequence
    int token = prompt_tokens[0]; // kick off with the first token in the prompt
    int pos = 0;     // position in the sequence
    while (pos < steps) {

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);

        // advance the state machine
        if (pos < num_prompt_tokens - 1) {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos + 1];
        } else {
            // otherwise sample the next token from the logits
            next = sample(sampler, logits);
        }
        pos++;

        // data-dependent terminating condition: the BOS (=1) token delimits sequences
        if (next == 1) { break; }

        // print the token as string, decode it with the Tokenizer object
        char* piece = decode(tokenizer, token, next);
        safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
        fflush(stdout);
        token = next;

        // init the timer here because the first iteration can be slower
        if (start == 0) { start = time_in_ms(); }
    }
    printf("\n");

    // report achieved tok/s (pos-1 because the timer starts after first iteration)
    if (pos > 1) {
        long end = time_in_ms();
        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
    }

    free(prompt_tokens);
}

void read_stdin(const char* guide, char* buffer, size_t bufsize) {
    // read a line from stdin, up to but not including \n
    printf("%s", guide);
    if (fgets(buffer, bufsize, stdin) != NULL) {
        size_t len = strlen(buffer);
        if (len > 0 && buffer[len - 1] == '\n') {
            buffer[len - 1] = '\0'; // strip newline
        }
    }
}

// ----------------------------------------------------------------------------
// chat loop
// I manually inspected the tokens for a few chat conversations compared to
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.

void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
          char *cli_user_prompt, char *cli_system_prompt, int steps) {

    // buffers for reading the system prompt and user prompt from stdin
    // you'll notice they are somewhat haphazardly and unsafely set atm
    char system_prompt[512];
    char user_prompt[512];
    char rendered_prompt[1152];
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
    int user_idx;

    // start the main loop
    int8_t user_turn = 1; // user starts
    int next;        // will store the next token in the sequence
    int token;       // stores the current token to feed into the transformer
    int prev_token;
    int pos = 0;     // position in the sequence
    while (pos < steps) {

        // when it is the user's turn to contribute tokens to the dialog...
        if (user_turn) {
            // get the (optional) system prompt at position 0
            if (pos == 0) {
                // at position 0, the user can also contribute a system prompt
                if (cli_system_prompt == NULL) {
                    // system prompt was not passed in, attempt to get it from stdin
                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
                } else {
                    // system prompt was passed in, use it
                    strcpy(system_prompt, cli_system_prompt);
                }
            }
            // get the user prompt
            if (pos == 0 && cli_user_prompt != NULL) {
                // user prompt for position 0 was passed in, use it
                strcpy(user_prompt, cli_user_prompt);
            } else {
                // otherwise get user prompt from stdin
                read_stdin("User: ", user_prompt, sizeof(user_prompt));
            }
            // render user/system prompts into the Llama 2 Chat schema
            if (pos == 0 && system_prompt[0] != '\0') {
                char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
                sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
            } else {
                char user_template[] = "[INST] %s [/INST]";
                sprintf(rendered_prompt, user_template, user_prompt);
            }
            // encode the rendered prompt into tokens
            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
            user_idx = 0; // reset the user index
            user_turn = 0;
            printf("Assistant: ");
        }

        // determine the token to pass into the transformer next
        if (user_idx < num_prompt_tokens) {
            // if we are still processing the input prompt, force the next prompt token
            token = prompt_tokens[user_idx++];
        } else {
            // otherwise use the next token sampled from previous turn
            token = next;
        }
        // EOS (=2) token ends the Assistant turn
        if (token == 2) { user_turn = 1; }

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);
        next = sample(sampler, logits);
        pos++;

        if (user_idx >= num_prompt_tokens && next != 2) {
            // the Assistant is responding, so print its output
            char* piece = decode(tokenizer, token, next);
            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
            fflush(stdout);
        }
        if (next == 2) { printf("\n"); }
    }
    printf("\n");
    free(prompt_tokens);
}


// ----------------------------------------------------------------------------
// CLI, include only if not testing
#ifndef TESTING

void error_usage() {
    fprintf(stderr, "Usage:   run <checkpoint> [options]\n");
    fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
    fprintf(stderr, "Options:\n");
    fprintf(stderr, "  -t <float>  temperature in [0,inf], default 1.0\n");
    fprintf(stderr, "  -p <float>  p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
    fprintf(stderr, "  -s <int>    random seed, default time(NULL)\n");
    fprintf(stderr, "  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n");
    fprintf(stderr, "  -i <string> input prompt\n");
    fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
    fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
    fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
    exit(EXIT_FAILURE);
}

int main(int argc, char *argv[]) {
//REVIEW
    createCublasHandle();
//REVIEW END
    // default parameters
    char *checkpoint_path = NULL;  // e.g. out/model.bin
    char *tokenizer_path = "tokenizer.bin";
    float temperature = 1.0f;   // 0.0 = greedy deterministic. 1.0 = original. don't set higher
    float topp = 0.9f;          // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
    int steps = 256;            // number of steps to run for
    char *prompt = NULL;        // prompt string
    unsigned long long rng_seed = 0; // seed rng with time by default
    char *mode = "generate";    // generate|chat
    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode

    // poor man's C argparse so we can override the defaults above from the command line
    if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
    for (int i = 2; i < argc; i+=2) {
        // do some basic validation
        if (i + 1 >= argc) { error_usage(); } // must have arg after flag
        if (argv[i][0] != '-') { error_usage(); } // must start with dash
        if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
        // read in the args
        if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
        else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
        else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
        else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
        else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
        else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
        else { error_usage(); }
    }

    // parameter validation/overrides
    if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
    if (temperature < 0.0) temperature = 0.0;
    if (topp < 0.0 || 1.0 < topp) topp = 0.9;
    if (steps < 0) steps = 0;

    // build the Transformer via the model .bin file
    Transformer transformer;
    build_transformer(&transformer, checkpoint_path);
    if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length

    // build the Tokenizer via the tokenizer .bin file
    Tokenizer tokenizer;
    build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);

    // build the Sampler
    Sampler sampler;
    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);

    // run!
    if (strcmp(mode, "generate") == 0) {
        generate(&transformer, &tokenizer, &sampler, prompt, steps);
    } else if (strcmp(mode, "chat") == 0) {
        chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
    } else {
        fprintf(stderr, "unknown mode: %s\n", mode);
        error_usage();
    }

    // memory and file handles cleanup
    free_sampler(&sampler);
    free_tokenizer(&tokenizer);
    free_transformer(&transformer);
//REVIEW
    destroyCublasHandle();
//REVIEW END
    return 0;
}
#endif
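The cutoff in `sample_topp` relies on one unstated step. Suppose a token with probability below (1 - topp)/(n - 1) were needed. Every token from it onward in sorted order is at most as likely. The most likely token (probability at least 1/n) clears the cutoff whenever topp ≥ 1/n, so at most n - 1 tokens remain from that point on. Together they hold less than 1 - topp, which means the tokens ahead of it already hold more than topp, and the scan would have stopped earlier.

The standalone sketch below splits the same logic into two hypothetical helpers, `nucleus_size` and `nucleus_pick`. Those names are not in run.cu. It applies them to a toy five-token distribution. It mirrors the listing for illustration and is not a drop-in replacement.

```c
#include <stdlib.h>

/* Mirrors ProbIndex from the listing so this sketch compiles on its own. */
typedef struct {
    float prob;
    int index;
} ProbIndex;

/* Descending order by probability, as compare() does in run.cu. */
static int by_prob_desc(const void *a, const void *b) {
    const ProbIndex *x = (const ProbIndex *)a;
    const ProbIndex *y = (const ProbIndex *)b;
    if (x->prob > y->prob) return -1;
    if (x->prob < y->prob) return 1;
    return 0;
}

/* Hypothetical helper: fills buf (capacity n, assumes n > 1) with the nucleus,
 * i.e. the shortest descending-probability prefix whose mass exceeds topp,
 * and returns its length. *mass receives the cumulative probability of the
 * kept tokens (all candidates if rounding never pushes the sum past topp). */
int nucleus_size(const float *probs, int n, float topp, ProbIndex *buf, float *mass) {
    /* Tokens below this cutoff can never be in the nucleus, so skip sorting them. */
    const float cutoff = (1.0f - topp) / (n - 1);
    int n0 = 0;
    for (int i = 0; i < n; i++) {
        if (probs[i] >= cutoff) {
            buf[n0].index = i;
            buf[n0].prob = probs[i];
            n0++;
        }
    }
    qsort(buf, (size_t)n0, sizeof(ProbIndex), by_prob_desc);

    float cum = 0.0f;
    int len = n0; /* rounding fallback: keep every candidate */
    for (int i = 0; i < n0; i++) {
        cum += buf[i].prob;
        if (cum > topp) { len = i + 1; break; }
    }
    *mass = cum;
    return len;
}

/* Hypothetical helper: draws one token id from the nucleus; coin is in [0, 1). */
int nucleus_pick(const ProbIndex *buf, int len, float mass, float coin) {
    float r = coin * mass; /* rescale so the nucleus mass plays the role of 1 */
    float cdf = 0.0f;
    for (int i = 0; i < len; i++) {
        cdf += buf[i].prob;
        if (r < cdf) return buf[i].index;
    }
    return buf[len - 1].index; /* rounding fallback, as in sample_topp */
}
```

Take probabilities {0.1, 0.45, 0.05, 0.35, 0.05} with topp = 0.7. The cutoff is 0.3/4 = 0.075, so the two 0.05 tokens are never sorted. The running sum goes 0.45, then 0.8 > 0.7, so the nucleus is tokens 1 and 3. A coin of 0.6 gives r = 0.48, which lands on token 3.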
 nvcc -O3 -o crun run.cu -lm -lcuda -lcublas
┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories15M.bin
Once upon a time, there was a dog named Spot. Spot loved to play and run around the yard. One day, Spot found a ball and wanted to play with it. He tried to pick it up with his mouth, but it was too heavy. Spot's owner saw him struggling and said, "Be careful, Spot. Sometimes we might hurt ourselves." Spot was sad because he really wanted to play with the ball.
Spot's owner had an idea. She went to the store and bought some new stuffed animals to play with. Spot was so excited that he forgot about the ball and played with his new friends for hours. When it was time for bed, Spot's owner said, "I'm sorry, Spot. I'll buy you a new ball tomorrow." Spot was happy again and fell asleep dreaming of all the fun he would have with his new friends.
achieved tok/s: 2333.333334

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories15M.bin
Once upon a time, there was a grumpy little boy named Timmy. He never wanted to play with his friends or eat his lunch. One day, Timmy's mom asked him to help her with the laundry. Timmy didn't want to help because he wanted to play. But then his mom showed him a warm place to put his clothes in the washing machine.
Timmy didn't want to do it, but his mom said it was important to take care of his clothes. So, Timmy put his clothes in the washing machine and turned it on. He was very excited to see his clothes fit again. When the washing machine was done, Timmy's mom said he could go play outside.
But then, Timmy realized that he had left his favorite toy behind. He went to look for it and found it under his bed. His mom said that the toy was in his shirt, but he had forgotten to put it back. Timmy felt sad and realized that he should have listened to his mom. He promised to never forget to put his toys in his shirt again.
The moral of the story is that it's important to listen to your parents
achieved tok/s: 2420.886075

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories15M.bin
Once upon a time, there was a unique bird named Bobo. Bobo had many colors on his body. He loved to fly and sing. One day, Bobo found a big nest on a tree. The nest was in a yard. Bobo was very happy.
Bobo met a small bug named Titi. Titi said, "Hi, Bobo! Can you help me find my home?" Bobo said, "Yes, I will help you!" They looked for Titi's home together. They found it under a bush.
Bobo and Titi went inside the nest. But the nest was dark. There was only one room. Bobo wanted to sleep, but Titi wanted to see more of the room. Titi was not happy. Bobo was sad. The end.
achieved tok/s: 2411.483253

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories15M.bin
One day, a little boy named Tim went to the store with his mom. He saw a big, impressive toy. It was a saw. Tim wanted the saw to build a toy house.
His mom said, "Okay, Tim. I will get it for you." She gave Tim some money and he went to the store. Tim put the saw in his bag.
At home, Tim saw a long line of ants. He wanted to take the saw with him. He asked the store man, "How do I get the ant in the line?" The store man said, "You can use the saw to pick up the ant."
Tim tried to pick up the ants. He made a line for his toy house. But, oh no! The toy house fell down. The saw was not a toy. It was a camera! The machine started to make sounds.
Tim was scared. He did not know what to do. Then, a big wind came. The wind blew the camera away. The machine stopped. Tim was safe.
The park man saw what happened. He said, "You were right, the machine fell down. I will help you." Tim was happy.
achieved tok/s: 2361.111111

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories15M.bin
Once upon a time, there was a little boy named Tim. Tim loved to play outside with his friends. One day, Tim found a big, hard rock. He wanted to show it to his friends, but he was too short to reach it.
Tim's friends came over and asked him to play a game. They said, "Tim, you are a player, so I will lift you up to the rock." Tim felt scared because the other kids were bigger and faster. He wanted to be the one to get the rock, but he was still too small.
Then, Tim had an idea. He asked his mom to help him lift the rock. Together, they lifted the rock and brought it to the other side of the garden. All of Tim's friends were happy to see the rock. They said, "Wow, Tim, you are a great player!" Tim felt proud and not ashamed anymore. He learned that sometimes, you can't do everything you want, but you can still find a way to make it.
achieved tok/s: 2505.747126

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories42M.bin
Once upon a time, there was a boy named Timmy. Timmy loved to play outside and run around. One day, Timmy's mom said, "Timmy, it's time for your haircut." Timmy didn't want to cut his hair because he was scared.
But his mom said, "Don't worry, Timmy. You can always start at the same time and come back after." Timmy finally agreed and went to the haircut place. The lady cutting his hair was very polite and made him feel comfortable.
After the haircut, Timmy looked in the mirror and saw how handsome he looked. He felt happy and proud of himself. Timmy learned that sometimes, it's scary to start somewhere, but if you start at your favorite favorite spot and are polite to others, good things can happen.
achieved tok/s: 2317.557251

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories42M.bin
Once upon a time, there was a little girl named Lily. She loved to play with her toys and watch cartoons on TV. One day, Lily was playing with her dolls when she heard a knock on the door. It was her friend, Sarah.
"Hi Lily, do you want to come play outside?" Sarah asked.
"No, I want to stay inside and play with my dolls," Lily replied.
Sarah looked sad and said, "I promise we can still play together. Maybe we can make a tower with these blocks."
Lily thought for a moment and said, "That's a good idea, but I'm not very good at building towers. It might be hard."
Sarah smiled and said, "That's okay. We can do it together. We'll make a tower that's really good. I promise it won't hurt at all."
Lily smiled back and they went outside to play. They built a tower that was much better than the one they had before. Lily was happy she had a friend to play with and Sarah was happy they were both playing together. They promised to play together again soon.
achieved tok/s: 2350.558660

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories42M.bin
Once upon a time, there was a little girl named Lily. She loved going to the mall with her mommy. One day, while they were walking around, Lily saw a big, hairy dog. She pointed at it and said, "Look, mommy! A dog!"
Suddenly, the dog started barking and running towards them. Lily got scared and hid behind her mommy. But then, the dog stopped barking and licked Lily's face! She was surprised but happy that the dog was friendly.
From that day on, Lily and her mommy went to the mall every weekend to see the hairy dog. They would point and say hello to him and even give him a hug. And the dog became their best friend. The end.
achieved tok/s: 2398.739495

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories42M.bin
Once upon a time, there was a little girl named Lily. She had a soft, blue mattress in her room that she loved to jump on. One day, Lily's mom told her to be careful and not jump on the mattress anymore because it was old and could break.
Lily didn't listen and kept jumping on the mattress. Suddenly, the mattress broke and Lily fell down. She started to cry and said, "Mommy, my mattress is broken!"
Her mom came running and said, "I told you not to jump on it. Now we have to buy a new one."
Lily was sad because she loved her old mattress. She learned that sometimes we need to listen to our parents and not do things that are not safe.
achieved tok/s: 2250.574712

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories42M.bin
Once upon a time, there was a little girl named Lily. She loved to watch TV and play with her toys. One day, she accidentally knocked over a candle and it started to burn the curtains. She quickly blew it out and put out the fire.
Lily was scared because she knew she shouldn't have been playing with fire. Her mommy came in and saw what happened. She hugged Lily and told her it was okay, accidents happen.
After that, Lily was more careful and didn't play with fire again. She learned that fire is dangerous and should be used with care. Her mommy was proud of her for being responsible and not being ignorant of the consequences of her actions. From that day on, Lily only watched TV when her mommy was with her and never played with fire again.
achieved tok/s: 2384.765625

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories110M.bin
Once upon a time, there was a little girl named Lily. She loved to play with her dolls and teddy bears. One day, Lily's mommy gave her a toy phone to play with. Lily was so happy and she played with it all day long.
Suddenly, Lily's teddy bear fell on the ground and got a rough spot on its head. Lily was sad and didn't know what to do. She picked up her toy phone and pretended to call her friend. "Hello, friend! Can you come and help me fix my teddy bear?" she said.
Lily's mommy heard her talking and came to see what was wrong. "What happened, Lily?" she asked. "My teddy bear has a rough spot on its head. Can you help me fix it?" Lily said. Her mommy smiled and said, "Of course, I can help you. Let's go find a cloth to wipe the rough spot." And they went to find a cloth together.
achieved tok/s: 1368.440367

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories110M.bin
Once upon a time, there was a kind rabbit named Benny. He loved to play with his friends in the forest. One day, Benny found a big truck with a license plate that said "Benny". Benny didn't know what a license was, but he knew it was important.
Benny decided to take the truck to the nearby town to show it to his friends. But when he got there, he saw that the truck was actually a delivery truck for a massage company. The delivery man told Benny that he couldn't show the truck to his friends because it was not polite to be nosy.
Benny felt bad for being nosy and decided to deliver the license plate back to the factory. He hopped all the way there and gave it back to the delivery man. The delivery man was very happy and thanked Benny for being honest.
From that day on, Benny learned that it's important to be respectful and not nosy. He also learned that doing the right thing is always the best choice.
achieved tok/s: 1386.245352

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories110M.bin
One day, a boy named Tim found a broken toy in his room. It was a small car with a button that said "reverse". Tim did not know what it meant, but he wanted to play with it. He pushed the button, and the car started to move backwards.
Tim's mom saw him playing with the broken toy. She said, "Tim, let's find the label on the car. It shows what it does." They looked around and found a label with the car's name on it. The label had a picture of a blue car.
Tim and his mom went to the park where they found the blue car. Tim pushed the button on the broken car again. This time, the car went forward. Tim's mom said, "Now we know what the label says. The car can reverse!" Tim smiled and played with his car all day at the park.
achieved tok/s: 1438.461539

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories110M.bin
One day, a little boy named Tim went for a walk. He was an obedient boy who always did what his mom and dad told him to do. While walking, he found a big book on the ground. It was a dictionary! Tim was very happy and picked it up.
Tim carried the dictionary home to show his mom and dad. They were proud of him for being so obedient and finding the dictionary. They told him that the dictionary could help him learn new words.
Every day, Tim would carry the dictionary with him when he went for a walk. He would read it and learn new words. He became a very smart and obedient boy, just like the dictionary helped him.
achieved tok/s: 1475.528700

┌──(amamitsu㉿amamitsu)-[~/Applications/Lab6]
└─$ ./crun stories110M.bin
One day, a boy named Tim went to a new place with his mom. This place was big and had many different things to see. Tim was very happy to be there. He saw a big tree with a cat on it. The cat was scared and did not know how to get down.
Tim said to his mom, "Mom, can we help the cat?" His mom said, "Yes, we can try." They found a long stick and tried to help the cat down. But the stick was too short. The cat was still scared and did not know what to do.
Then, a big bird came and landed on the tree. The bird looked at the cat and started to talk! The bird said, "I can help the cat." The bird flew down and picked up the cat with its beak. The cat was not scared anymore. Tim and his mom were very surprised. They thanked the bird and went home happy.
achieved tok/s: 1331.345162
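The tok/s figures above come from `time_in_ms`, which reads `CLOCK_REALTIME` and truncates to whole milliseconds. At these rates a run lasts at most a couple of hundred milliseconds. A 1 ms step is therefore roughly 1% of the measured interval. The real-time clock can also jump if the system time is changed mid-run.

A possible refinement, sketched below, assumes a POSIX system. It reads `CLOCK_MONOTONIC`, which is not affected by discontinuous jumps in the system time, and keeps sub-millisecond precision. `time_in_ms_precise` and `tokens_per_second` are hypothetical names. The rate formula is the same `(pos - 1) / elapsed * 1000` used in `generate`.

```c
#define _POSIX_C_SOURCE 199309L
#include <time.h>

/* Hypothetical replacement for time_in_ms(): milliseconds as a double,
 * read from CLOCK_MONOTONIC so wall-clock adjustments cannot distort
 * intervals, and without truncating the nanosecond part. */
double time_in_ms_precise(void) {
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return (double)ts.tv_sec * 1e3 + (double)ts.tv_nsec / 1e6;
}

/* Same formula as generate(): the timer starts after the first forward
 * pass, so only pos - 1 tokens are counted. Returns 0 when there is
 * nothing meaningful to report. */
double tokens_per_second(int pos, double start_ms, double end_ms) {
    if (pos <= 1 || end_ms <= start_ms) return 0.0;
    return (pos - 1) / (end_ms - start_ms) * 1000.0;
}
```

Swapping this in would mean changing `start` and `end` in `generate` from `long` to `double`.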
Output rate (tokens/s) on an NVIDIA H100 NVL, mean of the five runs per model above:

| Implementation | stories15M | stories42M | stories110M |
|---|---|---|---|
| CUBLAS | 2406.5121798 | 2340.4391486 | 1400.004224 |
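Each cell in the table is the arithmetic mean of the five `achieved tok/s` lines logged for that model. The short sketch below recomputes the means from values copied out of the log. `mean` and the `runs_*` arrays are illustrative names and are not part of run.cu.

```c
#include <stddef.h>

/* Arithmetic mean of n logged tok/s values (0 for an empty list). */
double mean(const double *v, size_t n) {
    double sum = 0.0;
    for (size_t i = 0; i < n; i++) sum += v[i];
    return n ? sum / (double)n : 0.0;
}

/* tok/s values copied from the five runs per model in the log above. */
const double runs_15m[5]  = {2333.333334, 2420.886075, 2411.483253, 2361.111111, 2505.747126};
const double runs_42m[5]  = {2317.557251, 2350.558660, 2398.739495, 2250.574712, 2384.765625};
const double runs_110m[5] = {1368.440367, 1386.245352, 1438.461539, 1475.528700, 1331.345162};
```

For stories15M, for example, the five values sum to 12032.560899, and dividing by 5 gives 2406.5121798. The stories110M model comes out at roughly 58% of the stories15M rate.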