我正在做一个矩阵乘法的项目.我已经能够编写C代码,并且能够使用Microsoft Visual Studio 2012编译器为其生成汇编代码.编译器生成的代码如下所示.编译器使用了SSE寄存器,这正是我想要的,但它不是最好的代码.我想优化此代码并将其与C代码内联地编写,但我不理解汇编代码.基本上,汇编代码仅适用于矩阵的一维,下面的代码仅适用于4 x 4矩阵.我该怎么做才能使其适合n * n矩阵大小.
C代码如下所示:
#define MAX_NUM 10
#define MAX_DIM 4
/*
 * Multiplies two 4x4 row-major matrices with the textbook triple loop and
 * stores the product in `result`. Nothing is printed; the program only
 * exists so the compiler's output for the loop nest can be inspected.
 *
 * Returns 0.
 */
int main () {
    float mat_a [] = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f};
    float mat_b [] = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f};
    float result [] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
    int num_row = 4;
    int num_col = 4;

    for (int i = 0; i < num_row; i++) {
        for (int j = 0; j < num_col; j++) {
            float sum = 0.0f;
            /*
             * k walks along row i of mat_a, whose rows are num_col elements
             * long, so the bound is num_col. The original used num_row,
             * which only works while the matrices are square.
             */
            for (int k = 0; k < num_col; k++) {
                sum += mat_a[i * num_col + k] * mat_b[k * num_col + j];
            }
            result[i * num_col + j] = sum;
        }
    }
    return 0;
}
汇编代码如下所示:
; Listing generated by Microsoft (R) Optimizing Compiler Version 17.00.50727.1
TITLE C:\Users\GS\Documents\Visual Studio 2012\Projects\Assembly_InLine\Assembly_InLine\Source.cpp
.686P
.XMM
include listing.inc
.model flat
INCLUDELIB MSVCRTD
INCLUDELIB OLDNAMES
PUBLIC _main
PUBLIC __real@00000000
PUBLIC __real@3f800000
PUBLIC __real@40000000
PUBLIC __real@40400000
PUBLIC __real@40800000
EXTRN @_RTC_CheckStackVars@8:PROC
EXTRN @__security_check_cookie@4:PROC
EXTRN __RTC_InitBase:PROC
EXTRN __RTC_Shutdown:PROC
EXTRN ___security_cookie:DWORD
EXTRN __fltused:DWORD
; COMDAT __real@40800000
CONST SEGMENT
__real@40800000 DD 040800000r ; 4
CONST ENDS
; COMDAT __real@40400000
CONST SEGMENT
__real@40400000 DD 040400000r ; 3
CONST ENDS
; COMDAT __real@40000000
CONST SEGMENT
__real@40000000 DD 040000000r ; 2
CONST ENDS
; COMDAT __real@3f800000
CONST SEGMENT
__real@3f800000 DD 03f800000r ; 1
CONST ENDS
; COMDAT __real@00000000
CONST SEGMENT
__real@00000000 DD 000000000r ; 0
CONST ENDS
; COMDAT rtc$TMZ
rtc$TMZ SEGMENT
__RTC_Shutdown.rtc$TMZ DD FLAT:__RTC_Shutdown
rtc$TMZ ENDS
; COMDAT rtc$IMZ
rtc$IMZ SEGMENT
__RTC_InitBase.rtc$IMZ DD FLAT:__RTC_InitBase
rtc$IMZ ENDS
; Function compile flags: /Odtp /RTCsu /ZI
; COMDAT _main
_TEXT SEGMENT
_k$1 = -288 ; size = 4
_j$2 = -276 ; size = 4
_i$3 = -264 ; size = 4
_sum$= -252 ; size = 4
_num_col$= -240 ; size = 4
_num_row$= -228 ; size = 4
_result$= -216 ; size = 64
_mat_b$= -144 ; size = 64
_mat_a$= -72 ; size = 64
__$ArrayPad$= -4 ; size = 4
_main PROC ; COMDAT
; File c:\users\gs\documents\visual studio 2012\projects\assembly_inline\assembly_inline\source.cpp
; Line 4
push ebp
mov ebp, esp
sub esp, 484 ; 000001e4H
push ebx
push esi
push edi
lea edi, DWORD PTR [ebp-484]
mov ecx, 121 ; 00000079H
mov eax, -858993460 ; ccccccccH
rep stosd
mov eax, DWORD PTR ___security_cookie
xor eax, ebp
mov DWORD PTR __$ArrayPad$[ebp], eax
; Line 5
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_a$[ebp], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_a$[ebp+4], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_a$[ebp+8], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_a$[ebp+12], xmm0
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_a$[ebp+16], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_a$[ebp+20], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_a$[ebp+24], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_a$[ebp+28], xmm0
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_a$[ebp+32], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_a$[ebp+36], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_a$[ebp+40], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_a$[ebp+44], xmm0
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_a$[ebp+48], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_a$[ebp+52], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_a$[ebp+56], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_a$[ebp+60], xmm0
; Line 6
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_b$[ebp], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_b$[ebp+4], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_b$[ebp+8], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_b$[ebp+12], xmm0
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_b$[ebp+16], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_b$[ebp+20], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_b$[ebp+24], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_b$[ebp+28], xmm0
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_b$[ebp+32], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_b$[ebp+36], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_b$[ebp+40], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_b$[ebp+44], xmm0
movss xmm0, DWORD PTR __real@3f800000
movss DWORD PTR _mat_b$[ebp+48], xmm0
movss xmm0, DWORD PTR __real@40000000
movss DWORD PTR _mat_b$[ebp+52], xmm0
movss xmm0, DWORD PTR __real@40400000
movss DWORD PTR _mat_b$[ebp+56], xmm0
movss xmm0, DWORD PTR __real@40800000
movss DWORD PTR _mat_b$[ebp+60], xmm0
; Line 7
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+4], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+8], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+12], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+16], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+20], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+24], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+28], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+32], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+36], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+40], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+44], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+48], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+52], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+56], xmm0
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _result$[ebp+60], xmm0
; Line 9
mov DWORD PTR _num_row$[ebp], 4
; Line 10
mov DWORD PTR _num_col$[ebp], 4
; Line 14
mov DWORD PTR _i$3[ebp], 0
jmp SHORT $LN9@main
$LN8@main:
mov eax, DWORD PTR _i$3[ebp]
add eax, 1
mov DWORD PTR _i$3[ebp], eax
$LN9@main:
mov eax, DWORD PTR _i$3[ebp]
cmp eax, DWORD PTR _num_row$[ebp]
jge $LN7@main
; Line 15
mov DWORD PTR _j$2[ebp], 0
jmp SHORT $LN6@main
$LN5@main:
mov eax, DWORD PTR _j$2[ebp]
add eax, 1
mov DWORD PTR _j$2[ebp], eax
$LN6@main:
mov eax, DWORD PTR _j$2[ebp]
cmp eax, DWORD PTR _num_col$[ebp]
jge $LN4@main
; Line 16
movss xmm0, DWORD PTR __real@00000000
movss DWORD PTR _sum$[ebp], xmm0
; Line 17
mov DWORD PTR _k$1[ebp], 0
jmp SHORT $LN3@main
$LN2@main:
mov eax, DWORD PTR _k$1[ebp]
add eax, 1
mov DWORD PTR _k$1[ebp], eax
$LN3@main:
mov eax, DWORD PTR _k$1[ebp]
cmp eax, DWORD PTR _num_row$[ebp]
jge SHORT $LN1@main
; Line 18
mov eax, DWORD PTR _i$3[ebp]
imul eax, DWORD PTR _num_col$[ebp]
add eax, DWORD PTR _k$1[ebp]
mov ecx, DWORD PTR _k$1[ebp]
imul ecx, DWORD PTR _num_col$[ebp]
add ecx, DWORD PTR _j$2[ebp]
movss xmm0, DWORD PTR _mat_a$[ebp+eax*4]
mulss xmm0, DWORD PTR _mat_b$[ebp+ecx*4]
addss xmm0, DWORD PTR _sum$[ebp]
movss DWORD PTR _sum$[ebp], xmm0
; Line 19
jmp SHORT $LN2@main
$LN1@main:
; Line 20
mov eax, DWORD PTR _i$3[ebp]
imul eax, DWORD PTR _num_col$[ebp]
lea ecx, DWORD PTR _result$[ebp+eax*4]
mov edx, DWORD PTR _j$2[ebp]
movss xmm0, DWORD PTR _sum$[ebp]
movss DWORD PTR [ecx+edx*4], xmm0
; Line 21
jmp $LN5@main
$LN4@main:
; Line 22
jmp $LN8@main
$LN7@main:
; Line 24
xor eax, eax
; Line 25
push edx
mov ecx, ebp
push eax
lea edx, DWORD PTR $LN16@main
call @_RTC_CheckStackVars@8
pop eax
pop edx
pop edi
pop esi
pop ebx
mov ecx, DWORD PTR __$ArrayPad$[ebp]
xor ecx, ebp
call @__security_check_cookie@4
mov esp, ebp
pop ebp
ret 0
npad 1
$LN16@main:
DD 3
DD $LN15@main
$LN15@main:
DD -72 ; ffffffb8H
DD 64 ; 00000040H
DD $LN12@main
DD -144 ; ffffff70H
DD 64 ; 00000040H
DD $LN13@main
DD -216 ; ffffff28H
DD 64 ; 00000040H
DD $LN14@main
$LN14@main:
DB 114 ; 00000072H
DB 101 ; 00000065H
DB 115 ; 00000073H
DB 117 ; 00000075H
DB 108 ; 0000006cH
DB 116 ; 00000074H
DB 0
$LN13@main:
DB 109 ; 0000006dH
DB 97 ; 00000061H
DB 116 ; 00000074H
DB 95 ; 0000005fH
DB 98 ; 00000062H
DB 0
$LN12@main:
DB 109 ; 0000006dH
DB 97 ; 00000061H
DB 116 ; 00000074H
DB 95 ; 0000005fH
DB 97 ; 00000061H
DB 0
_main ENDP
_TEXT ENDS
END
解决方法:
Visual Studio和SSE在这里只是转移注意力的次要问题(C与C++之争同样无关紧要).假设您是在发布(Release)模式下编译的,那么您的代码效率低下另有原因,对于大型矩阵尤其如此.主要原因是它对缓存不友好.为了使代码对任意n * n矩阵都高效,您需要针对矩阵大小进行优化.
在使用SIMD或线程之前,先针对缓存进行优化非常重要.在下面的代码中,我仅使用一个线程,不使用SSE / AVX,只靠分块乘法就将1024×1024矩阵的运算速度提高了十倍以上(旧代码为7.1 s,新代码为0.6 s).如果您的代码受内存带宽限制,那么使用SIMD不会带来任何好处.
我已经在这里描述了使用转置对矩阵乘法进行的一阶改进.
OpenMP C++ Matrix Multiplication run slower in parallel
但是,让我描述一个更加缓存友好的方法.假设您的硬件具有两种类型的内存:
>小而快速,
>大而慢.
实际上,现代CPU具有多个级别的存储(L1较小且较快,L2较大且较慢,L3更大且更慢,主内存则更大、更慢;有些CPU甚至具有L4),但是这种只有两个级别的简单模型仍然能带来巨大的性能提升.
在这种只有两种内存的模型下可以证明:将矩阵划分为能够放入小而快速的内存的正方形瓦片,并进行分块矩阵乘法,可以获得最佳性能.接下来,还要重新排列内存,使每个瓦片内的元素在内存中连续存放.
下面的代码显示了如何执行此操作.我在1024×1024矩阵上使用了64×64的块大小.您的代码花费了7s,而我的花费了0.65s.矩阵的维度必须是块大小(64)的倍数,但是很容易将其扩展为任意大小的矩阵.如果要查看如何优化块内计算的示例,请参见: Difference in performance between MSVC and GCC for highly optimized matrix multplication code
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <omp.h>
/*
 * Copies the row-major n x n matrix `a` into `b` tile by tile, so that each
 * bs x bs tile ends up as one contiguous run of bs*bs floats in `b`.
 *
 * Tiles are written in row-major tile order, and the elements inside a tile
 * are row-major as well. Only the leading (n/bs)*bs rows and columns of `a`
 * are visited, so `n` is expected to be a multiple of `bs`.
 *
 *   a   source matrix, n*n floats
 *   b   destination buffer, written sequentially from index 0
 *   n   matrix dimension
 *   bs  tile edge length
 */
void reorder(float *a, float *b, int n, int bs) {
    int tiles = n / bs;
    float *out = b;

    for (int ti = 0; ti < tiles; ti++) {
        for (int tj = 0; tj < tiles; tj++) {
            /* Top-left element of tile (ti, tj) in the original layout. */
            const float *tile = a + bs * (ti * n + tj);
            for (int r = 0; r < bs; r++) {
                const float *row = tile + r * n;
                for (int col = 0; col < bs; col++) {
                    *out++ = row[col];
                }
            }
        }
    }
}
/*
 * Reference matrix multiply: c += a * b for row-major n x n matrices,
 * using the plain triple loop with no blocking.
 *
 *   a, b  input matrices, n*n floats each
 *   c     accumulator matrix, n*n floats; its existing contents are added to
 *   n     matrix dimension
 */
void gemm_slow(float *a, float *b, float *c, int n) {
    for(int i=0; i<n; i++) {
        for(int j=0; j<n; j++) {
            /* The running sum already starts from the current value of c. */
            float sum = c[i*n+j];
            for(int k=0; k<n; k++) {
                sum += a[i*n+k]*b[k*n+j];
            }
            /*
             * Plain store, not +=: sum already includes the old c value, so
             * adding it again would count a non-zero starting value twice.
             */
            c[i*n+j] = sum;
        }
    }
}
/*
 * Multiplies one n2 x n2 tile pair and accumulates into the output tile:
 * c_tile += a_tile * b_tile.
 *
 *   a   top-left element of the A tile; consecutive rows are n floats apart
 *   b   B tile stored contiguously; consecutive rows are n2 floats apart
 *   c   top-left element of the C tile; consecutive rows are n floats apart
 *   n   row stride used for a and c
 *   n2  tile edge length
 */
void gemm_block(float *a, float *b, float *c, int n, int n2) {
    for (int r = 0; r < n2; r++) {
        const float *a_row = a + r * n;
        float *c_row = c + r * n;
        for (int col = 0; col < n2; col++) {
            /* Start from the value already in c so partial products add up. */
            float acc = c_row[col];
            const float *b_col = b + col;
            for (int k = 0; k < n2; k++) {
                acc += a_row[k] * b_col[k * n2];
            }
            c_row[col] = acc;
        }
    }
}
/*
 * Blocked matrix multiply: c += a * b for row-major n x n matrices, working
 * on bs x bs tiles.
 *
 * b is first copied into a temporary buffer in which every tile is
 * contiguous (see reorder), then each output tile is accumulated from the
 * matching row of A tiles and column of B tiles via gemm_block.
 *
 *   a, b  input matrices, n*n floats each
 *   c     accumulator matrix, n*n floats; existing contents are added to
 *   n     matrix dimension; only the leading (n/bs)*bs rows/columns are used
 *   bs    tile edge length (must be non-zero)
 *
 * Terminates the process with an error message if the temporary buffer
 * cannot be allocated.
 */
void gemm(float *a, float*b, float*c, int n, int bs) {
    int nb = n/bs;
    size_t count = (size_t)n * (size_t)n;
    float *b2 = (float*)malloc(sizeof(float)*count);
    /* malloc(0) may legitimately return NULL, so only a non-empty request counts as failure. */
    if (b2 == NULL && count != 0) {
        fprintf(stderr, "gemm: cannot allocate %d x %d tile buffer\n", n, n);
        exit(EXIT_FAILURE);
    }
    reorder(b,b2,n,bs);
    for(int i=0; i<nb; i++) {
        for(int j=0; j<nb; j++) {
            for(int k=0; k<nb; k++) {
                /* A and C tiles keep stride n; the B tile is contiguous in b2. */
                gemm_block(&a[bs*(i*n+k)],&b2[bs*bs*(k*nb+j)],&c[bs*(i*n+j)], n, bs);
            }
        }
    }
    free(b2);
}
/*
 * Benchmark driver: fills two n x n matrices with random values in [0, 1],
 * multiplies them once with gemm_slow and once with the blocked gemm, and
 * prints the wall-clock time of each run followed by the memcmp result of
 * the two outputs (0 means bit-identical).
 */
int main() {
    const int bs = 64;
    const int n = 1024;
    float *a = new float[n*n];
    float *b = new float[n*n];
    /* The trailing () value-initializes, so both accumulators start at zero. */
    float *c1 = new float[n*n]();
    float *c2 = new float[n*n]();
    for(int i=0; i<n*n; i++) {
        a[i] = 1.0*rand()/RAND_MAX;
        b[i] = 1.0*rand()/RAND_MAX;
    }

    double dtime;
    dtime = omp_get_wtime();
    gemm_slow(a,b,c1,n);
    dtime = omp_get_wtime() - dtime;
    printf("%f\n", dtime);

    dtime = omp_get_wtime();
    /* Pass the declared block size instead of repeating the literal 64. */
    gemm(a,b,c2,n,bs);
    dtime = omp_get_wtime() - dtime;
    printf("%f\n", dtime);

    printf("%d\n", memcmp(c1,c2, sizeof(float)*n*n));

    delete[] a;
    delete[] b;
    delete[] c1;
    delete[] c2;
    return 0;
}