Skip to content

Commit

Permalink
[PYTHON] Fixed formatting issue in conv.c
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet committed Mar 26, 2021
1 parent 8e15a54 commit 6c2e3d0
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions python/triton/ops/conv.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ __global__ void conv(TYPE *A __noalias __readonly,
// memory strides
int lda_z, int lda_ci, int lda_h, int lda_w,
int ldb_ci, int ldb_r, int ldb_s, int ldb_co,
int ldc_z, int ldc_co, int ldc_p, int ldc_q) {
int ldc_z, int ldc_co, int ldc_p, int ldc_q)
{
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
Expand Down Expand Up @@ -47,19 +48,13 @@ __global__ void conv(TYPE *A __noalias __readonly,
int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :];

// pointers to lhs
int offa[TM, TK] = rz[:, newaxis] * lda_z +
rci [newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
int offa[TM, TK] = rz[:, newaxis] * lda_z + rci [newaxis, :] * lda_ci +
rh * lda_h + rw * 1;
TYPE *pa[TM, TK] = A + offa;
int *padelta[TK] = ADELTA + rk;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr
[:, newaxis] * ldb_r +
rs
[:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
int offb[TK, TN] = rci[:, newaxis] * ldb_ci + rr[:, newaxis] * ldb_r +
rs[:, newaxis] * ldb_s + rn [newaxis, :] * 1;
TYPE *pb[TK, TN] = B + offb;

// prefetches operands
Expand All @@ -72,7 +67,8 @@ __global__ void conv(TYPE *A __noalias __readonly,

// reduction loop
float acc[TM, TN] = 0;
for (int k = K; k > 0; k -= TK) {
for (int k = K; k > 0; k -= TK)
{
acc += a @b;
// increment A
int adelta[TK] = *padelta;
Expand Down Expand Up @@ -103,12 +99,8 @@ __global__ void conv(TYPE *A __noalias __readonly,
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz[:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co +
rp
[:, newaxis] * ldc_p +
rq
[:, newaxis] * 1;
int offc[TM, TN] = rz[:, newaxis] * ldc_z + rn [newaxis, :] * ldc_co +
rp[:, newaxis] * ldc_p + rq[:, newaxis] * 1;
TYPE *pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N;

Expand Down

0 comments on commit 6c2e3d0

Please sign in to comment.