@@ -37,64 +37,59 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
 
 template <typename scalar_t>
 __global__ void lltm_cuda_forward_kernel(
-    const scalar_t* __restrict__ gates,
-    const scalar_t* __restrict__ old_cell,
-    scalar_t* __restrict__ new_h,
-    scalar_t* __restrict__ new_cell,
-    scalar_t* __restrict__ input_gate,
-    scalar_t* __restrict__ output_gate,
-    scalar_t* __restrict__ candidate_cell,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    input_gate[index] = sigmoid(gates[gates_row + column]);
-    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
-    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
-    new_cell[index] =
-        old_cell[index] + candidate_cell[index] * input_gate[index];
-    new_h[index] = tanh(new_cell[index]) * output_gate[index];
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)) {
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
   }
 }
 
 template <typename scalar_t>
 __global__ void lltm_cuda_backward_kernel(
-    scalar_t* __restrict__ d_old_cell,
-    scalar_t* __restrict__ d_gates,
-    const scalar_t* __restrict__ grad_h,
-    const scalar_t* __restrict__ grad_cell,
-    const scalar_t* __restrict__ new_cell,
-    const scalar_t* __restrict__ input_gate,
-    const scalar_t* __restrict__ output_gate,
-    const scalar_t* __restrict__ candidate_cell,
-    const scalar_t* __restrict__ gate_weights,
-    size_t state_size) {
-  const int column = blockIdx.x * blockDim.x + threadIdx.x;
-  const int index = blockIdx.y * state_size + column;
-  const int gates_row = blockIdx.y * (state_size * 3);
-  if (column < state_size) {
-    const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
-    const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
+    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)) {
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
     const auto d_new_cell =
-        d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
 
 
-    d_old_cell[index] = d_new_cell;
-    const auto d_candidate_cell = input_gate[index] * d_new_cell;
-    const auto d_input_gate = candidate_cell[index] * d_new_cell;
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
 
-
-    const auto input_gate_index = gates_row + column;
-    const auto output_gate_index = gates_row + state_size + column;
-    const auto candidate_cell_index = gates_row + 2 * state_size + column;
-
-    d_gates[input_gate_index] =
-        d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
-    d_gates[output_gate_index] =
-        d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
-    d_gates[candidate_cell_index] =
-        d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
   }
 }
 } // namespace
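For orientation (not part of the diff): the accessor indexing gates[n][i][c] in the new kernels addresses the same element the old kernels reached with hand-computed flat offsets. A minimal sketch of that correspondence, assuming the gate buffer is contiguous with shape [batch_size, 3, state_size]; gate_at is a hypothetical helper used only for illustration:

// Sketch only: the flat offset the old kernel computed by hand, for a
// contiguous [batch_size, 3, state_size] buffer. gate_at(gates, n, i, c, S)
// reads the same element the new code accesses as gates[n][i][c].
template <typename scalar_t>
__device__ __forceinline__ scalar_t gate_at(
    const scalar_t* __restrict__ gates, int n, int i, int c, int state_size) {
  const int gates_row = n * (state_size * 3);     // old: blockIdx.y * (state_size * 3)
  return gates[gates_row + i * state_size + c];   // old: gates_row + {0,1,2} * state_size + column
}

The accessor also carries the sizes, which is why the bounds check becomes c < gates.size(2) instead of relying on a separately passed state_size argument.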
@@ -106,11 +101,12 @@ std::vector<torch::Tensor> lltm_cuda_forward(
     torch::Tensor old_h,
     torch::Tensor old_cell) {
   auto X = torch::cat({old_h, input}, /*dim=*/1);
-  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
 
   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);
 
+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
   auto new_h = torch::zeros_like(old_cell);
   auto new_cell = torch::zeros_like(old_cell);
   auto input_gate = torch::zeros_like(old_cell);
@@ -122,14 +118,13 @@ std::vector<torch::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.data<scalar_t>(),
-        old_cell.data<scalar_t>(),
-        new_h.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        state_size);
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
   }));
 
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
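Note: the blocks and threads used in the launch above are defined earlier in lltm_cuda_forward and are unchanged by this diff. A launch configuration consistent with the kernels' indexing, where blockIdx.y selects the batch row and blockIdx.x * blockDim.x + threadIdx.x selects the state column, would look like the sketch below; treat the exact values as an assumption rather than part of this change.

// Sketch: one thread per state column, one grid row per batch element.
const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);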
@@ -143,10 +138,10 @@ std::vector<torch::Tensor> lltm_cuda_backward(
     torch::Tensor output_gate,
     torch::Tensor candidate_cell,
     torch::Tensor X,
-    torch::Tensor gate_weights,
+    torch::Tensor gates,
     torch::Tensor weights) {
   auto d_old_cell = torch::zeros_like(new_cell);
-  auto d_gates = torch::zeros_like(gate_weights);
+  auto d_gates = torch::zeros_like(gates);
 
   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
@@ -156,22 +151,22 @@ std::vector<torch::Tensor> lltm_cuda_backward(
 
   AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.data<scalar_t>(),
-        d_gates.data<scalar_t>(),
-        grad_h.data<scalar_t>(),
-        grad_cell.data<scalar_t>(),
-        new_cell.data<scalar_t>(),
-        input_gate.data<scalar_t>(),
-        output_gate.data<scalar_t>(),
-        candidate_cell.data<scalar_t>(),
-        gate_weights.data<scalar_t>(),
-        state_size);
+        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
+        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
+        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
   }));
 
-  auto d_weights = d_gates.t().mm(X);
-  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
 
-  auto d_X = d_gates.mm(weights);
+  auto d_X = d_gate_weights.mm(weights);
   auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
   auto d_input = d_X.slice(/*dim=*/1, state_size);
 
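Shape bookkeeping for the new tail of lltm_cuda_backward, written out as a sketch: d_gates now carries the per-gate dimension explicitly, and flatten(1, 2) restores the 2-D matrix the reductions expect. Here input_features is only a placeholder name for the number of input columns in X, which is [batch_size, state_size + input_features] after the forward pass's torch::cat.

// d_gates        : [batch_size, 3, state_size]
// d_gate_weights : [batch_size, 3 * state_size]                    = d_gates.flatten(1, 2)
// d_weights      : [3 * state_size, state_size + input_features]   = d_gate_weights.t().mm(X)
// d_bias         : [1, 3 * state_size]                             = d_gate_weights.sum(0, /*keepdim=*/true)
// d_X            : [batch_size, state_size + input_features]       = d_gate_weights.mm(weights)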