6 changes: 3 additions & 3 deletions example/12_reduce/reduce_blockwise.cpp
@@ -147,8 +147,6 @@ class SimpleAppArgs

 int main(int argc, char* argv[])
 {
-    using namespace ck::host_reduce;
-
     const std::vector<int> reduceDims{0, 1, 2};
     const std::vector<int> invariantDims{3};

@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
     ReductionHost<InDataType,
                   AccDataType,
                   OutDataType,
-                  ReduceOpId,
+                  ReduceOperation,
+                  InElementwiseOperation,
+                  AccElementwiseOperation,
                   Rank,
                   NumReduceDim,
                   PropagateNan,
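
With this change the host reference reduction is parameterized by the same operation types as the device kernel instead of by the `ReduceOpId` enum. A minimal sketch of where the three new template arguments come from, assuming the mapping traits from `reduction_operator_mapping.hpp` used later in this PR (the `true, true` flags are copied from `pool2d_fwd_common.hpp`; their exact meaning is not visible in this diff):

```cpp
// Sketch only: AccDataType and ReduceOpId are assumed to be defined by the
// example, and reduction_operator_mapping.hpp to provide the two traits.
using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation = typename ck::
    reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
    reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
```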
6 changes: 3 additions & 3 deletions example/12_reduce/reduce_blockwise_two_call.cpp
@@ -108,8 +108,6 @@ int main(int argc, char* argv[])

     const std::vector<size_t> outLengths = {64, 320, 80};

-    using namespace ck::host_reduce;
-
     if(argc == 1)
     {
         do_verify = true;

@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
     ReductionHost<InOutDataType,
                   AccDataType,
                   InOutDataType,
-                  ReduceOpId,
+                  ReduceOperation,
+                  InElementwiseOperation,
+                  AccElementwiseOperation,
                   5, // Rank
                   2, // NumReduceDim
                   PropagateNan,
46 changes: 26 additions & 20 deletions example/13_pool2d_fwd/pool2d_fwd_common.hpp
@@ -8,10 +8,12 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
-#include "host_reduce_util.hpp"
 #include "device_tensor.hpp"
 #include "tensor_layout.hpp"
 #include "reduction_enums.hpp"
+#include "reduction_operator_mapping.hpp"
+#include "reduction_functions_accumulate.hpp"
+
 #include "device_pool2d_fwd_nhwc_nhwc.hpp"

 template <typename InDataType,
@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                              const std::array<ck::index_t, 2>& in_left_pads,
                              const std::array<ck::index_t, 2>& /*in_right_pads*/)
 {
-    using namespace ck::host_reduce;
-
     const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];

-    const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
-    const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
+    using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using InElementwiseOperation = typename ck::
+        reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+    using AccElementwiseOperation = typename ck::
+        reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
+
+    const InElementwiseOperation in_elementwise_op(divider);
+    const AccElementwiseOperation acc_elementwise_op(divider);

     if constexpr(!OutputIndex)
     {
-        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
+        using Accumulation =
+            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

         auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
-            auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            auto accuVal = ReduceOperation::GetIdentityValue();

             for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
             {
@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                     {
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));

-                        PreUnaryOp(currVal);
+                        in_elementwise_op(currVal, currVal);

-                        binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                        Accumulation::Calculate(accuVal, currVal);
                     }
                 }
             }

-            PosUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             out(n, c, ho, wo) = accuVal;
         };
@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
     }
     else
     {
-        auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
-
-        auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
-            auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
+                                                                        ReduceOperation,
+                                                                        AccDataType,
+                                                                        IndexDataType>;
+        auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
+            auto accuVal = ReduceOperation::GetIdentityValue();
             IndexDataType accuIndex = 0;

             for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
                         IndexDataType currIndex = y * window_spatial_lengths[1] + x;

-                        PreUnaryOp(currVal);
+                        in_elementwise_op(currVal, currVal);

-                        binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                            opReduce, accuVal, currVal, accuIndex, currIndex);
+                        Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                     }
                 }
             }

-            PosUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             out(n, c, ho, wo) = accuVal;
             out_indices(n, c, ho, wo) = accuIndex;
@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
                ck::index_t in_right_pad_h,
                ck::index_t in_right_pad_w)
 {
-    using namespace ck::host_reduce;
-
     using DevicePoolFwdInstance =
         ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
             InDataType, // InDataType
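
The free helpers `binop_with_nan_check` / `binop_with_index_and_nan_check` from `ck::host_reduce` are replaced by the `Accumulation` policy types from `reduction_functions_accumulate.hpp`. A self-contained sketch of what a value-only accumulate-with-NaN-check of this shape plausibly does (an illustration, not the CK source):

```cpp
#include <cmath>

// Illustrative stand-in for ck::detail::AccumulateWithNanCheck. When
// PropagateNan is set, a NaN input poisons the accumulator; otherwise the
// binary reduce operation (e.g. max) folds the current value in.
template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheckSketch
{
    static void Calculate(AccDataType& accuVal, AccDataType currVal)
    {
        if constexpr(PropagateNan)
        {
            if(std::isnan(currVal))
            {
                accuVal = currVal;
                return;
            }
        }

        ReduceOperation{}(accuVal, currVal); // e.g. accuVal = max(accuVal, currVal)
    }
};
```

The indexed variant used in the `OutputIndex` branch additionally tracks the window offset of the winning element, which is what max-pooling needs to report indices.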
2 changes: 0 additions & 2 deletions example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;

 int main(int argc, char* argv[])
 {
-    using namespace ck::host_reduce;
-
     bool do_verification;
     int init_method;
     bool time_kernel;
2 changes: 0 additions & 2 deletions example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;

 int main(int argc, char* argv[])
 {
-    using namespace ck::host_reduce;
-
     bool do_verification;
     int init_method;
     bool time_kernel;
2 changes: 1 addition & 1 deletion example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
@@ -236,7 +236,7 @@ int main(int argc, char* argv[])

     for(int m = 0; m < M; ++m)
     {
-        ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal();
+        ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();

         for(int n = 0; n < N; ++n)
             d_reduce_op(d_acc, c_m_n_host_result(m, n));
@@ -261,8 +261,8 @@ int main(int argc, char* argv[])

     for(int m = 0; m < M; ++m)
     {
-        float d0_acc = d0_reduce_op.GetReductionZeroVal();
-        float d1_acc = d1_reduce_op.GetReductionZeroVal();
+        float d0_acc = d0_reduce_op.GetIdentityValue();
+        float d1_acc = d1_reduce_op.GetIdentityValue();

         for(int n = 0; n < N; ++n)
         {
@@ -259,8 +259,8 @@ int main(int argc, char* argv[])
     {
         for(int m = 0; m < M; ++m)
         {
-            float d0_acc = d0_reduce_op.GetReductionZeroVal();
-            float d1_acc = d1_reduce_op.GetReductionZeroVal();
+            float d0_acc = d0_reduce_op.GetIdentityValue();
+            float d1_acc = d1_reduce_op.GetIdentityValue();

             for(int n = 0; n < N; ++n)
             {
4 changes: 2 additions & 2 deletions example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
     auto reduceSumOpInst = ReduceSumOp{};
     for(int m = 0; m < M; ++m)
     {
-        float mean_acc = reduceSumOpInst.GetReductionZeroVal();
-        float square_mean_acc = reduceSumOpInst.GetReductionZeroVal();
+        float mean_acc = reduceSumOpInst.GetIdentityValue();
+        float square_mean_acc = reduceSumOpInst.GetIdentityValue();

         for(int n = 0; n < N; ++n)
         {
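
`ReduceSumOp`'s identity value is 0, so this is a pure rename at these call sites. For reference, the surrounding host code is the classic one-pass mean/variance pattern; a self-contained sketch (names are illustrative, not from this file):

```cpp
#include <vector>

// One-pass mean/variance via E[x^2] - E[x]^2, the pattern used by
// host_gemm_layernorm: both accumulators start at the sum identity, 0.
void mean_variance(const std::vector<float>& row, float& mean, float& variance)
{
    float mean_acc        = 0.f; // identity value of a sum reduction
    float square_mean_acc = 0.f;

    for(float v : row)
    {
        mean_acc += v;
        square_mean_acc += v * v;
    }

    const float n = static_cast<float>(row.size());
    mean     = mean_acc / n;
    variance = square_mean_acc / n - mean * mean;
}
```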
@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE

         if constexpr(use_multiblock)
         {
-            const auto zeroVal =
-                ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
+            const auto identityVal =
+                ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
                     OutMemoryDataOperation);

             const auto kernel_pre =
@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                     0,
                     out_grid_desc_m_2,
                     arg.out_dev_,
-                    zeroVal);
+                    identityVal);
             };

             avg_time += launch_and_time_kernel(stream_config,
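
In the `use_multiblock` path the output buffer is pre-filled by `kernel_pre` before many workgroups combine partial results through the atomic `OutMemoryDataOperation`, so the fill value must be the identity of that memory operation rather than of the reduce operation itself. An illustrative sketch of the helper's contract (not the CK source; the enum and helper names here are stand-ins):

```cpp
#include <limits>

enum class InMemoryDataOperationEnumSketch { Set, AtomicAdd, AtomicMax };

// Value a pre-kernel writes to every output element so that subsequent
// atomic updates from all workgroups compose correctly.
template <typename T>
T identity_value_for_memory_op(InMemoryDataOperationEnumSketch op)
{
    switch(op)
    {
    case InMemoryDataOperationEnumSketch::AtomicAdd:
        return T{0}; // out = out + partial, so seed with 0
    case InMemoryDataOperationEnumSketch::AtomicMax:
        return std::numeric_limits<T>::lowest(); // out = max(out, partial)
    default: return T{0};
    }
}
```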
@@ -1,5 +1,6 @@
 #pragma once
 #include "data_type.hpp"
+#include "math_v2.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -296,36 +297,31 @@ struct UnaryAbs<float, float>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
 };

 template <>
 struct UnaryAbs<half_t, half_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
+    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
 };

 template <>
 struct UnaryAbs<double, double>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
 };

 template <>
 struct UnaryAbs<int8_t, int8_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const
-    {
-        int8_t sgn = x >> (8 - 1);
-
-        y = (x ^ sgn) - sgn;
-    };
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
 };

 template <typename Y, typename X>
@@ -336,15 +332,18 @@ struct UnarySqrt<float, float>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
 };

 template <>
 struct UnarySqrt<double, double>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const
+    {
+        y = ck::math::sqrt(x);
+    };
 };

 } // namespace element_wise
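
Routing all specializations through `ck::math::abs` and `ck::math::sqrt` removes the per-type spellings (`abs`, `__habs`, `sqrtf`) and the hand-rolled `int8_t` bit trick. A sketch of the kind of overload set `math_v2.hpp` presumably provides (illustrative, not the actual header; a HIP/CUDA compiler is assumed for the `__host__ __device__` qualifiers):

```cpp
#include <cstdint>

namespace ck {
namespace math {

// Host/device-callable overloads under one name, so callers need no
// per-type dispatch.
__host__ __device__ inline float abs(float x) { return x < 0.f ? -x : x; }
__host__ __device__ inline double abs(double x) { return x < 0.0 ? -x : x; }

// Keeps the branch-free idiom that the removed UnaryAbs<int8_t, int8_t>
// specialization used: sgn is all-ones if x is negative, zero otherwise.
__host__ __device__ inline int8_t abs(int8_t x)
{
    int8_t sgn = x >> (8 - 1);

    return (x ^ sgn) - sgn;
}

} // namespace math
} // namespace ck
```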
@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

        const auto in_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                    thread_k_cluster_id * KThreadSliceSize));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });

@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                    in_thread_idx_buf);

        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            AccDataType tmpValue = zeroVal;
+            AccDataType tmpValue = identityVal;
            IndexDataType tmpIndex = 0;

            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                        in_thread_val_buf(Number<offset>{}));
                });

-            AccDataType tmpValue = zeroVal;
+            AccDataType tmpValue = identityVal;
            IndexDataType tmpIndex = 0;

            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
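
The identity value does double duty in this kernel: it initializes the per-thread accumulators, and it is the out-of-bounds fill value of `in_global_val_buf`, so reads from the padded tail of the K dimension cannot change the result. A self-contained illustration of that invariant with a max reduction (names are illustrative, not from this file):

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Padding a reduction input with the identity value (here, the lowest float
// for a max reduction) leaves the result unchanged, which is why the dynamic
// buffer above is constructed with identityVal as its fill value.
float max_reduce_padded(std::vector<float> data, std::size_t padded_size)
{
    const float identity = std::numeric_limits<float>::lowest();
    data.resize(padded_size, identity); // padded tail filled with the identity

    float acc = identity;
    for(float v : data)
        acc = std::max(acc, v);
    return acc; // equal to the max over the unpadded data
}
```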