diff --git a/problems/bioml/trimul/reference.py b/problems/bioml/trimul/reference.py index 86c1ee1..e02653d 100644 --- a/problems/bioml/trimul/reference.py +++ b/problems/bioml/trimul/reference.py @@ -134,7 +134,9 @@ def generate_input( # Generate input tensor based on distribution if distribution == "cauchy": # Heavier tail distribution - input_tensor = torch.distributions.Cauchy(0, 2).sample( + zero = torch.tensor(0.0, device="cuda") + two = torch.tensor(2.0, device="cuda") + input_tensor = torch.distributions.Cauchy(zero, two).sample( (batch_size, seq_len, seq_len, dim) ).to(device='cuda', dtype=torch.float32) else: # normal distribution @@ -165,4 +167,4 @@ def generate_input( return (input_tensor, mask, weights, config) -check_implementation = make_match_reference(ref_kernel, rtol=2e-2, atol=2e-2) \ No newline at end of file +check_implementation = make_match_reference(ref_kernel, rtol=2e-2, atol=2e-2)