-
Notifications
You must be signed in to change notification settings - Fork 212
Adding Tensor.roll and torch.roll #621 #623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
/// Elements that are shifted beyond the last position are re-introduced at the first position. | ||
/// If a dimension is not specified, the tensor will be flattened before rolling and then restored to the original shape. | ||
/// </summary> | ||
public Tensor roll((long,long) shifts, (long,long) dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't able to make the 'dims' argument a span, since it didn't let me default to 'null' Is there a solution to that? An empty span?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just pass empty span when having null dims. You can pass default
value for the dims parameter.
/// Elements that are shifted beyond the last position are re-introduced at the first position. | ||
/// If a dimension is not specified, the tensor will be flattened before rolling and then restored to the original shape. | ||
/// </summary> | ||
public Tensor roll((long,long) shifts, (long,long) dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/TorchSharp/Tensor/Tensor.cs
Outdated
/// Elements that are shifted beyond the last position are re-introduced at the first position. | ||
/// If a dimension is not specified, the tensor will be flattened before rolling and then restored to the original shape. | ||
/// </summary> | ||
public Tensor roll(IEnumerable<long> shifts, IEnumerable<long>? dims = null) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added minor comments and a suggestion if you want to consider it. LGTM, otherwise.
src/TorchSharp/Tensor/Tensor.cs
Outdated
public Tensor roll(long shifts, long? dims = null) | ||
{ | ||
if (dims.HasValue) { | ||
return roll(new long[] { shifts }, new long[] { dims.Value }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm... When I tried changing to ReadOnlySpan for a few other APIs (torch.zeros, for example), that seems to not work with F#.
@dsyme -- this doesn't work if torch.zeros takes a ReadOnlySpan:
let mutable pe = torch.zeros([| maxLen; dmodel|])
Do I need to overload ROS and IEnumerable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I need to overload ROS and IEnumerable?
It seems I need to have roll(ReadOnlySpan,...) and roll(long[],...) as well as a private _roll(ReadOnlySpan...) with the actual implementation.
I've been looking for other places to use ReadOnlySpan, but I will put that into separate PR. |
Fixed issue #621