|
|
|
|
|
|
|
|
|
|
|
|
|
#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <limits>
|
|
|
// Element-wise "bias + activation + scale" kernel.
//
// For each flat index xi in [0, size_x):
//   x   = p_x[xi], plus p_b[(xi / step_b) % size_b] when use_bias != 0
//   ref = p_ref[xi] when use_ref != 0, otherwise 0
//   out[xi] = y * scale, where y is selected by (act * 10 + grad):
//     10, 11, and any unrecognised code : y = x
//     12                                : y = 0
//     30                                : y = x > 0 ? x : x * alpha
//     31                                : y = ref > 0 ? x : x * alpha
//     32                                : y = 0
//   NOTE(review): grad == 1/2 look like first/second-order backward passes
//   (x = incoming gradient, ref = forward output) — confirm against callers.
//
// Launch layout: 1-D grid of 1-D blocks, no shared memory. Block b covers the
// contiguous range [b * loop_x * blockDim.x, (b + 1) * loop_x * blockDim.x).
// Each thread handles up to loop_x elements spaced blockDim.x apart, so
// adjacent threads touch adjacent addresses. The loop is bounds-checked
// against size_x.
// Preconditions: indices are 32-bit, so size_x must fit in an int. When
// use_bias / use_ref are 0, p_b / p_ref are never dereferenced and may be null.
template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  // Compare and zero-fill against a scalar_t constant rather than the double
  // literal 0.0. For float data, a double literal promotes every comparison
  // to double precision, which is slow on most GPUs. The results are the same.
  const scalar_t zero = 0.0f;

  // The activation selector does not change per element; compute it once.
  const int mode = act * 10 + grad;

  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      // Runs of step_b consecutive elements share one bias entry, cycling
      // through the size_b entries.
      x += p_b[(xi / step_b) % size_b];
    }

    const scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;
    switch (mode) {
    default: // unrecognised codes behave like the identity activation
    case 10:
    case 11:
      y = x;
      break;
    case 12:
      y = zero;
      break;
    case 30:
      y = (x > zero) ? x : x * alpha;
      break;
    case 31:
      y = (ref > zero) ? x : x * alpha;
      break;
    case 32:
      y = zero;
      break;
    }

    out[xi] = y * scale;
  }
}
|
|
|
// Host entry point. Launches fused_bias_act_kernel on the current CUDA stream
// of input's device and returns a new tensor with input's shape and dtype.
//
// Parameters:
//   input - CUDA tensor (float, double or half); made contiguous internally.
//   bias  - optional; pass an empty tensor to disable. Element i of the flat
//           input receives bias[(i / prod(input.sizes()[2:])) % bias.numel()].
//           For a bias of length input.size(1), that is a broadcast along
//           dim 1. Must be on input's device with input's dtype when non-empty.
//   refer - optional; pass an empty tensor to disable. It is read element-wise
//           (used by act=3, grad=1). When non-empty, it must have exactly
//           input.numel() elements, input's dtype, and input's device.
//   act, grad, alpha, scale - forwarded to the kernel (see its header comment).
// Returns: the output tensor, allocated with torch::empty_like(contiguous input).
// Throws c10::Error on invalid device/dtype/size arguments, on tensors too
// large for the kernel's 32-bit indexing, or when the kernel launch fails.
torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  TORCH_CHECK(input.is_cuda(), "fused_bias_act_op: input must be a CUDA tensor");

  // Make input's device current before querying the stream and launching, so
  // the kernel runs on the GPU that owns the data rather than on whatever
  // device happens to be current in the calling thread.
  const c10::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  const bool use_bias = b.numel() > 0;
  const bool use_ref = ref.numel() > 0;

  if (use_bias) {
    TORCH_CHECK(b.device() == x.device(),
                "fused_bias_act_op: bias must be on the same device as input");
    TORCH_CHECK(b.scalar_type() == x.scalar_type(),
                "fused_bias_act_op: bias must have the same dtype as input");
  }
  if (use_ref) {
    TORCH_CHECK(ref.device() == x.device(),
                "fused_bias_act_op: refer must be on the same device as input");
    TORCH_CHECK(ref.scalar_type() == x.scalar_type(),
                "fused_bias_act_op: refer must have the same dtype as input");
    // The kernel reads refer[i] for every input index i.
    TORCH_CHECK(ref.numel() == x.numel(),
                "fused_bias_act_op: refer must have as many elements as input");
  }

  // The kernel indexes with 32-bit ints.
  constexpr int64_t kMaxElems = std::numeric_limits<int>::max();
  TORCH_CHECK(x.numel() <= kMaxElems && b.numel() <= kMaxElems,
              "fused_bias_act_op: tensors with more than INT_MAX elements are "
              "not supported");

  auto y = torch::empty_like(x);

  const int size_x = static_cast<int>(x.numel());
  if (size_x == 0) {
    return y; // nothing to compute; skip the launch
  }

  const int size_b = static_cast<int>(b.numel());

  // Number of consecutive elements sharing one bias entry: the product of all
  // dimensions after dim 1. It is bounded by size_x, so it fits in an int.
  int step_b = 1;
  for (int64_t i = 2; i < x.dim(); i++) {
    step_b *= static_cast<int>(x.size(i));
  }

  const int loop_x = 4;          // elements handled per thread
  const int block_size = 4 * 32; // threads per block
  const int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            use_bias ? b.data_ptr<scalar_t>() : nullptr,
            use_ref ? ref.data_ptr<scalar_t>() : nullptr, act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias ? 1 : 0,
            use_ref ? 1 : 0);
      });
  // Kernel launches do not return errors; surface bad-config failures here.
  C10_CUDA_CHECK(cudaGetLastError());

  return y;
}