This repository implements a custom forward pass for a neural operation using CUDA and integrates it into PyTorch via C++/CUDA extensions. It supports:
- A naive implementation with custom CUDA kernels
- An optimized version using cuBLAS
- A PyTorch reference version for correctness and performance comparison
The operation consists of the following layers:
- Linear + Bias + ReLU: x1 = ReLU(input @ weight1 + bias)
- Linear + Square (element-wise): x2 = (x1 @ weight2) ** 2
- Outer Product: out = x2 ⊗ x2 → shape: (batch, n, n)
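A minimal PyTorch sketch of this chain of layers is shown below. The tensor names and shapes are illustrative assumptions; the repository's own reference implementation (used for correctness and performance comparison) is defined in the project files.

```python
import torch

def reference_forward(inp, weight1, bias, weight2):
    """Illustrative PyTorch reference of the operation (names/shapes assumed)."""
    x1 = torch.relu(inp @ weight1 + bias)      # Linear + Bias + ReLU
    x2 = (x1 @ weight2) ** 2                   # Linear + element-wise Square
    # Batched outer product: (batch, n) -> (batch, n, n), out[b, i, j] = x2[b, i] * x2[b, j]
    out = x2.unsqueeze(2) * x2.unsqueeze(1)
    return out
```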
| File | Description |
| --- | --- |
| binding.cpp | Pybind11 interface binding CUDA kernels to Python |
| kernels_naive.cu | Naive CUDA implementation of the forward pass |
| kernels_optimized.cu | Optimized CUDA version using cuBLAS |
| wrapper.py | Python interface to load and run kernels |
| test.py | Benchmarking and correctness testing |
| run_test.sh | Shell script to compile and run tests |
| requirements.txt | Python dependencies |
# Create a new conda environment with Python 3.11
conda create -n kernel_env python=3.11 -y
# Activate the environment
conda activate kernel_env
# Install additional dependencies from requirements.txt
pip install -r requirements.txt
To run the test directly on your local machine or login node:
python test.py
For cluster environments with SLURM job scheduler:
sbatch run_test.sh
This will submit a job to the cluster and run the tests on a compute node with GPU access.
The project implements three custom CUDA kernels:
- Dense ReLU Layer (dense_relu_k) - Performs matrix multiplication followed by bias addition and ReLU activation
- Dense Square Layer (dense_square_k) - Performs matrix multiplication and squares the result element-wise
- Outer Product Layer (outer_prod_k) - Computes the outer product of input vectors
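A hypothetical correctness check against PyTorch equivalents might look like the sketch below. The module handle `kernels` and the exported function names are assumptions for illustration; the actual Python-facing names are defined in binding.cpp and wrapper.py, and test.py performs the repository's real comparison.

```python
import torch

def check_kernels(kernels, batch=32, d_in=64, hidden=128, n=16, atol=1e-4):
    """Compare each custom kernel against its PyTorch equivalent (assumed API)."""
    dev = "cuda"
    inp = torch.randn(batch, d_in, device=dev)
    w1 = torch.randn(d_in, hidden, device=dev)
    b = torch.randn(hidden, device=dev)
    w2 = torch.randn(hidden, n, device=dev)

    x1 = kernels.dense_relu(inp, w1, b)             # assumed binding name
    assert torch.allclose(x1, torch.relu(inp @ w1 + b), atol=atol)

    x2 = kernels.dense_square(x1, w2)               # assumed binding name
    assert torch.allclose(x2, (x1 @ w2) ** 2, atol=atol)

    out = kernels.outer_prod(x2)                    # assumed binding name
    assert torch.allclose(out, x2.unsqueeze(2) * x2.unsqueeze(1), atol=atol)
```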
- Ninja not found: Make sure ninja is installed via pip
- CUDA compilation errors: Ensure CUDA toolkit is properly installed and accessible
- PyTorch version mismatch: Verify PyTorch is compiled with CUDA support
- Python 3.11
- PyTorch >= 2.0.0 with CUDA support
- CUDA Toolkit >= 12.1
- Ninja build system
- C++ compiler with C++17 support
- The code uses PyTorch's JIT compilation to build the CUDA extensions at runtime (see the sketch after this list)
- Make sure you have appropriate GPU access when running the tests
- The kernels are intended for educational purposes rather than production-grade performance
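A minimal sketch of the JIT build step is shown below. The module name, source list, and flags are assumptions for illustration; the actual build call is made in wrapper.py.

```python
from torch.utils.cpp_extension import load

# Compile the CUDA extension at runtime with PyTorch's JIT builder (requires ninja).
# Module name, sources, and flags are illustrative, not the repository's exact call.
kernels_naive = load(
    name="kernels_naive",
    sources=["binding.cpp", "kernels_naive.cu"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)
```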