代码如下：
/*
* ===========================================================================================
*
* Filename: transpose.c
*
* Description: transpose operator impl.
*
* Version:
* Create: 2021-11-07 14:08:50
* Revision: none
* Compiler: GCC:version 7.2.1 20170904 (release),ARM/embedded-7-branch revision 255204
*
* Author:
* Organization:
* Last Modified : 2021-11-07 20:22:56
*
* ===========================================================================================
*/
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
/*
 * Debug print helper: writes "<function> line <N>, <message>\n" to stdout.
 * `fmt` is pasted next to string literals, so it must itself be a string
 * literal. The GNU `##__VA_ARGS__` extension drops the trailing comma, so
 * DBG can be called with a format string and no further arguments.
 */
#define DBG(fmt, ...) do{ printf("%s line %d, "fmt"\n", __func__, __LINE__, ##__VA_ARGS__); } while(0)
/*
 * Permute the axes of a dense row-major tensor (last axis varies fastest).
 *
 * Axis i of the result is axis pos[i] of the input. The result therefore has
 * shape shape_B[i] = shape_A[pos[i]], and the element at A coordinates
 * (a_0, ..., a_{n-1}) lands at B coordinates (a_pos[0], ..., a_pos[n-1]).
 *
 * matrix_A : input data, shape_A[0] * ... * shape_A[dims_A-1] floats.
 * matrix_B : out-parameter. On success *matrix_B receives a newly allocated
 *            buffer that the caller must free(). On any failure it is set
 *            to NULL (provided matrix_B itself is non-NULL).
 * shape_A  : extent of each input axis (dims_A entries, each >= 0).
 * dims_A   : number of axes (>= 0; 0 means a single scalar element).
 * pos      : permutation of 0..dims_A-1 (dims_A entries).
 *
 * Failures are reported through DBG: invalid arguments, negative extents,
 * an element count that does not fit in an int, pos not being a
 * permutation, or an allocation failure.
 */
void transpose_matrix(float *matrix_A, float **matrix_B, int *shape_A, int dims_A, int *pos)
{
    float *B = NULL;
    int *shape_B = NULL;
    int *indexA = NULL;
    int *indexB = NULL;
    size_t axis_count;
    int element_count = 1;
    int i;

    if (matrix_B == NULL)
    {
        DBG("output pointer matrix_B is NULL.");
        return;
    }
    /* Give the caller a well-defined result on every failure path below. */
    *matrix_B = NULL;

    if (matrix_A == NULL || shape_A == NULL || pos == NULL || dims_A < 0)
    {
        DBG("invalid arguments.");
        return;
    }

    /* All flat-index arithmetic below is done in int, so reject negative
     * extents and element counts that would overflow int. */
    for (i = 0; i < dims_A; i++)
    {
        if (shape_A[i] < 0 ||
            (shape_A[i] > 0 && element_count > INT_MAX / shape_A[i]))
        {
            DBG("invalid or too large shape for A.");
            return;
        }
        element_count *= shape_A[i];
    }

    /* calloc(0, ...) may legally return NULL, so always request at least one
     * entry. calloc also checks the size multiplication for overflow. */
    axis_count = dims_A > 0 ? (size_t)dims_A : 1;
    shape_B = calloc(axis_count, sizeof *shape_B);
    if (shape_B == NULL)
    {
        DBG("malloc shape buffer for B failure.");
        goto cleanup;
    }
    indexA = calloc(axis_count, sizeof *indexA);
    if (indexA == NULL)
    {
        DBG("failure to malloc matrix A index.");
        goto cleanup;
    }
    indexB = calloc(axis_count, sizeof *indexB);
    if (indexB == NULL)
    {
        DBG("failure to malloc matrix B index.");
        goto cleanup;
    }

    /* pos must be a permutation of 0..dims_A-1; otherwise the gathers below
     * index out of bounds. The zero-filled indexA doubles as a "seen" table
     * here and is fully overwritten before its real use. */
    for (i = 0; i < dims_A; i++)
    {
        if (pos[i] < 0 || pos[i] >= dims_A || indexA[pos[i]] != 0)
        {
            DBG("pos is not a permutation of the axes of A.");
            goto cleanup;
        }
        indexA[pos[i]] = 1;
        shape_B[i] = shape_A[pos[i]];
    }

    B = calloc(element_count > 0 ? (size_t)element_count : 1, sizeof *B);
    if (B == NULL)
    {
        DBG("malloc buffer for B failure.");
        goto cleanup;
    }

    for (int src = 0; src < element_count; src++)
    {
        int temp = src;
        int dst = 0;

        /* Decompose the flat source offset into per-axis coordinates of A. */
        for (i = dims_A - 1; i >= 0; i--)
        {
            indexA[i] = temp % shape_A[i];
            temp = temp / shape_A[i];
        }
        /* The coordinate along axis i of B is the coordinate along axis pos[i] of A. */
        for (i = 0; i < dims_A; i++)
        {
            indexB[i] = indexA[pos[i]];
        }
        /* Recompose B's coordinates into a flat row-major offset. The running
         * stride never exceeds element_count, so it cannot overflow. */
        temp = 1;
        for (i = dims_A - 1; i >= 0; i--)
        {
            dst = dst + indexB[i] * temp;
            temp = temp * shape_B[i];
        }
        B[dst] = matrix_A[src];
    }

    *matrix_B = B;
    B = NULL; /* ownership transferred to the caller */

cleanup:
    free(B);
    free(indexB);
    free(indexA);
    free(shape_B);
}
/*
 * Print a dense row-major tensor to stdout.
 *
 * First prints "Array size: N", where N = shape[0] * ... * shape[dim-1].
 * Then it prints the N elements in memory order, each followed by a space.
 * A newline is emitted after every row of the last axis. One more newline
 * is emitted each time a larger sub-block (2-D slice, 3-D block, ...) is
 * completed, which separates slices with blank lines. The output always
 * ends with a newline when at least one element is printed.
 */
void print_tensor(const float* A, int* shape, int dim)
{
    int elem = 1;
    for (int i = 0; i < dim; i++)
    {
        elem = elem * shape[i];
    }
    printf("Array size: %d\n", elem);
    for (int i = 0; i < elem; i++)
    {
        printf("%f ", A[i]);
        /* split grows to the size of successively larger trailing sub-blocks. */
        int split = 1;
        for (int j = dim - 1; j > 0; j--)
        {
            split = split * shape[j];
            if ((i + 1) % split == 0)
            {
                printf("\n");
            }
        }
    }
    /* For dim <= 1 the j loop above never runs, so no newline is ever
     * printed. Terminate the single row here so later output starts on a
     * fresh line. */
    if (dim <= 1 && elem > 0)
    {
        printf("\n");
    }
}
/*
 * Demo driver. It transposes one 24-element buffer twice:
 *   1. viewed as a (2, 3, 4) tensor with perm (2, 0, 1);
 *   2. viewed as a (2, 2, 2, 3) tensor with perm (3, 0, 1, 2).
 * The input and the result of each case are printed. Returns EXIT_FAILURE
 * if either transpose fails.
 */
int main(void)
{
    /* Initialized so a failed transpose is detectable even if the callee
     * leaves the out-parameter untouched. */
    float *B = NULL;
    float A[24] =
    {
        1, 2, 3, 4,
        5, 6, 7, 8,
        9, 10, 11, 12,
        13, 14, 15, 16,
        17, 18, 19, 20,
        21, 22, 23, 24
    };

    /* Case 1: (2, 3, 4) -> (4, 2, 3). */
    int shapeA[] = {2, 3, 4};
    int dimA = 3;
    print_tensor(A, shapeA, dimA);
    // Transpose
    int perm[] = {2, 0, 1};
    transpose_matrix(A, &B, shapeA, dimA, perm);
    if (B == NULL)
    {
        fprintf(stderr, "transpose of the (2, 3, 4) tensor failed\n");
        return EXIT_FAILURE;
    }
    // Print B: its shape is shapeA permuted by perm, i.e. {4, 2, 3}.
    int dimB = dimA;
    int shapeB[3];
    for (int i = 0; i < dimB; i++)
    {
        shapeB[i] = shapeA[perm[i]];
    }
    print_tensor(B, shapeB, dimB);
    /* Release the first result before B is reused for the second transpose. */
    free(B);
    B = NULL;

    /* Case 2: the same data viewed as (2, 2, 2, 3) -> (3, 2, 2, 2). */
    int shapeM[] = {2, 2, 2, 3};
    int dimM = 4;
    print_tensor(A, shapeM, dimM);
    // Transpose
    int permM[] = {3, 0, 1, 2};
    transpose_matrix(A, &B, shapeM, dimM, permM);
    if (B == NULL)
    {
        fprintf(stderr, "transpose of the (2, 2, 2, 3) tensor failed\n");
        return EXIT_FAILURE;
    }
    // Print B: its shape is shapeM permuted by permM, i.e. {3, 2, 2, 2}.
    int dimO = dimM;
    int shapeO[4];
    for (int i = 0; i < dimO; i++)
    {
        shapeO[i] = shapeM[permM[i]];
    }
    print_tensor(B, shapeO, dimO);
    // Free memory
    free(B);
    return 0;
}
运行结果：
(base) caozilong@caozilong-Vostro-3268:~/Workspace/transpose$ ./a.out
Array size: 24
1.000000 2.000000 3.000000 4.000000
5.000000 6.000000 7.000000 8.000000
9.000000 10.000000 11.000000 12.000000
13.000000 14.000000 15.000000 16.000000
17.000000 18.000000 19.000000 20.000000
21.000000 22.000000 23.000000 24.000000
Array size: 24
1.000000 5.000000 9.000000
13.000000 17.000000 21.000000
2.000000 6.000000 10.000000
14.000000 18.000000 22.000000
3.000000 7.000000 11.000000
15.000000 19.000000 23.000000
4.000000 8.000000 12.000000
16.000000 20.000000 24.000000
Array size: 24
1.000000 2.000000 3.000000
4.000000 5.000000 6.000000
7.000000 8.000000 9.000000
10.000000 11.000000 12.000000
13.000000 14.000000 15.000000
16.000000 17.000000 18.000000
19.000000 20.000000 21.000000
22.000000 23.000000 24.000000
Array size: 24
1.000000 4.000000
7.000000 10.000000
13.000000 16.000000
19.000000 22.000000
2.000000 5.000000
8.000000 11.000000
14.000000 17.000000
20.000000 23.000000
3.000000 6.000000
9.000000 12.000000
15.000000 18.000000
21.000000 24.000000
(base) caozilong@caozilong-Vostro-3268:~/Workspace/transpose$