-
Notifications
You must be signed in to change notification settings - Fork 31
Add hl.rand op with seed arg lowering to tl.rand #652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
stack-info: PR: #652, branch: karthickai/stack/2
04d0e9b
to
300de6b
Compare
seed: int, | ||
dtype: torch.dtype = torch.float32, | ||
device: torch.device | None = None, | ||
) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to comment more on when user should use hl.rand
vs. torch.rand_like
(#530)? IMO ideally we should minimize the hl.*
API surface and encourage reuse of existing torch.*
APIs. But if there are cases where torch API would not work, then it makes sense to add hl.rand
API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure I'll add the comment clarifying the intent hl.rand
exists to explicitly pass a seed arg for deterministic randomness in helion kernels, whereas torch.rand_like
doesn't take seed arg.
stack-info: PR: #652, branch: karthickai/stack/2
300de6b
to
d82cc97
Compare
stack-info: PR: #652, branch: karthickai/stack/2
d82cc97
to
0d507e3
Compare
test/test_rng.py
Outdated
def test_hl_rand_3d(self): | ||
import helion | ||
|
||
@helion.kernel(ref_mode=helion.RefMode.EAGER) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe remove the ref_mode setting? The code as-is will explicitly test ref eager mode instead of normal Helion compile mode, but here I believe the intent is to test compile mode. (We have other harness to run tests in ref eager mode automatically so usually we don't need to worry about it.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the catch! I added that line for debugging to check the ref implementation and forgot to remove it. I’ve updated it now.
test/test_rng.py
Outdated
self.assertTrue(torch.all(output < 1.0), "All values should be < 1") | ||
|
||
def test_hl_rand_3d(self): | ||
import helion |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can likely remove
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed that line.
stack-info: PR: #652, branch: karthickai/stack/2
0d507e3
to
c112edb
Compare
Stacked PRs:
Add hl.rand op with seed arg lowering to tl.rand