Draft
Changes from all commits (101 commits)
2ff8ec4
SixLabors.ImageSharp version bump.
masaru-kimura-hacarus Sep 10, 2025
5b821cb
Dispose `Scalar`s implicitly created in `Tensor` operators
ds5678 Jan 10, 2025
069bb0d
Use Scalar operators in the primitive overloads for clarity and maint…
ds5678 Jan 10, 2025
5c79729
Introduce TorchSharp.ScalarLeakDetector.
masaru-kimura-hacarus Sep 10, 2025
fb0ff43
Introduce TorchSharp.TensorLeakDetector.
masaru-kimura-hacarus Sep 10, 2025
512a2ba
Declare TorchSharp.Scalar more explicitly.
masaru-kimura-hacarus Sep 11, 2025
06b9e45
Declare TorchSharp.Scalar explicitly.
masaru-kimura-hacarus Sep 11, 2025
0fb172d
Update Adadelta.step.
masaru-kimura-hacarus Sep 11, 2025
645523c
Update Adagrad.step.
masaru-kimura-hacarus Sep 11, 2025
3f0d806
Update Adam.step.
masaru-kimura-hacarus Sep 11, 2025
19d87cc
Update Adamax.step.
masaru-kimura-hacarus Sep 11, 2025
68f00b2
Update ASGD.step.
masaru-kimura-hacarus Sep 11, 2025
4cc71c6
Update NAdam.step.
masaru-kimura-hacarus Sep 11, 2025
997c679
Update RAdam.step.
masaru-kimura-hacarus Sep 11, 2025
bc130ff
Update RMSProp.step.
masaru-kimura-hacarus Sep 11, 2025
64520bc
Update SGD.step.
masaru-kimura-hacarus Sep 11, 2025
2f3c9e3
Update griffinlim.
masaru-kimura-hacarus Sep 11, 2025
d168161
Update AdamW.step.
masaru-kimura-hacarus Sep 11, 2025
90f4065
Update BatchNorm.forward.
masaru-kimura-hacarus Sep 12, 2025
a86d880
Update torch.normal.
masaru-kimura-hacarus Sep 12, 2025
6e54694
Update torchvision.utils.save_image.
masaru-kimura-hacarus Sep 12, 2025
67300a9
Update Rprop.step.
masaru-kimura-hacarus Sep 12, 2025
f252383
Update torchvision.ops.stochastic_depth.
masaru-kimura-hacarus Sep 12, 2025
1523386
Update torchvision.utils.make_grid.
masaru-kimura-hacarus Sep 12, 2025
bfc1f2b
Update TorchSharp.Modules.ExpRelaxedCategorical.log_prob.
masaru-kimura-hacarus Sep 12, 2025
3ec44e4
Update torchvision.transforms.functional.convert_image_dtype.
masaru-kimura-hacarus Sep 12, 2025
ed5c72d
Update torchaudio.functional.griffinlim.
masaru-kimura-hacarus Sep 12, 2025
c8fb482
Update torchaudio.functional._get_sinc_resample_kernel.
masaru-kimura-hacarus Sep 12, 2025
ae57dd6
Use THSTensor_square{,_}.
masaru-kimura-hacarus Sep 12, 2025
9cf47bf
Update torchaudio.functional.spectrogram.
masaru-kimura-hacarus Sep 12, 2025
b9e9da4
Update torchaudio.functional.inverse_spectrogram.
masaru-kimura-hacarus Sep 12, 2025
36983ab
Use torch.Tensor.square().
masaru-kimura-hacarus Sep 12, 2025
6e0c00b
Update InverseMelScale.forward.
masaru-kimura-hacarus Sep 12, 2025
d2a533f
Use torch.Tensor.square.
masaru-kimura-hacarus Sep 12, 2025
f9f81c5
Use torch.Tensor.square.
masaru-kimura-hacarus Sep 12, 2025
2be26a3
Update TorchSharp.torchvision.AdjustGamma.call.
masaru-kimura-hacarus Sep 16, 2025
c717fb3
Update torchvision.transforms.functional.adjust_gamma.
masaru-kimura-hacarus Sep 16, 2025
b266f02
Update torchvision.ops.sigmoid_focal_loss.
masaru-kimura-hacarus Sep 16, 2025
39b5091
Update torchvision.ops.sigmoid_focal_loss.
masaru-kimura-hacarus Sep 16, 2025
359e846
Update TorchSharp.Modules.PReLU constructor.
masaru-kimura-hacarus Sep 16, 2025
e4ffef2
Update TorchSharp.Modules.Rprop.State.Initialize.
masaru-kimura-hacarus Sep 16, 2025
995ab18
Update torchvision.transforms.functional.autocontrast.
masaru-kimura-hacarus Sep 16, 2025
c830437
Update torch.nn.functional.threshold.
masaru-kimura-hacarus Sep 16, 2025
81d021a
Update torch.Tensor.softplus.
masaru-kimura-hacarus Sep 16, 2025
3761a60
Update torch.Tensor.celu{,_}.
masaru-kimura-hacarus Sep 16, 2025
c69d041
Update torch.nn.functional.celu.
masaru-kimura-hacarus Sep 16, 2025
5daa318
Use torch.Tensor.elu_.
masaru-kimura-hacarus Sep 16, 2025
fdf7c92
Update torch.Tensor.elu{,_}.
masaru-kimura-hacarus Sep 16, 2025
3e38c0f
Update torch.nn.functional.hardtanh.
masaru-kimura-hacarus Sep 16, 2025
275eb77
Update torch.nn.functional.leaky_relu.
masaru-kimura-hacarus Sep 16, 2025
8483baa
Update AdversarialExampleGeneration.Attack.
masaru-kimura-hacarus Sep 16, 2025
af54440
Update TorchSharp.Modules.Dirichlet.mode.
masaru-kimura-hacarus Sep 16, 2025
7de5fa3
Update torch.distributions.Distribution.ClampProbs.
masaru-kimura-hacarus Sep 16, 2025
9ce8889
Update TorchSharp.Modules.NegativeBinomial.mode.
masaru-kimura-hacarus Sep 16, 2025
b6dd968
Update TorchSharp.Modules.Pareto.mean.
masaru-kimura-hacarus Sep 16, 2025
afcee2f
Update TorchSharp.Modules.Pareto.variance.
masaru-kimura-hacarus Sep 16, 2025
e5fb29d
Update torch.distributions.transforms.SigmoidTransform.
masaru-kimura-hacarus Sep 16, 2025
ce65592
Update torchvision.transforms.functional.Blend.
masaru-kimura-hacarus Sep 16, 2025
3d7c114
Update torchvision.ops.nms.
masaru-kimura-hacarus Sep 16, 2025
51d71c4
Update torchvision.ops.generalized_box_iou.
masaru-kimura-hacarus Sep 16, 2025
6d45bec
Update torchvision.ops._box_inter_union.
masaru-kimura-hacarus Sep 16, 2025
4dd9545
Update torchvision.ops._box_diou_iou.
masaru-kimura-hacarus Sep 16, 2025
d4a7d2d
Update torch.utils.tensorboard.Summary.image.
masaru-kimura-hacarus Sep 16, 2025
1bf7c45
Update torch.utils.tensorboard.Summary.video.
masaru-kimura-hacarus Sep 16, 2025
19fd88f
Update TorchSharp.Modules.FisherSnedecor.rsample.
masaru-kimura-hacarus Sep 16, 2025
29d5d2b
Update TorchSharp.Modules.Gamma.mode.
masaru-kimura-hacarus Sep 16, 2025
1bc5cb1
Update TorchSharp.Modules.Gamma.rsample.
masaru-kimura-hacarus Sep 16, 2025
93eb240
Update TorchSharp.Modules.Uniform.cdf.
masaru-kimura-hacarus Sep 16, 2025
b70ed58
Use torch.Tensor.clamp_{max,min}_.
masaru-kimura-hacarus Sep 16, 2025
ccb1678
Declare TorchSharp.Scalar explicitly.
masaru-kimura-hacarus Sep 16, 2025
34cd0ac
Update TorchSharp.Modules.GaussianNLLLoss.forward.
masaru-kimura-hacarus Sep 16, 2025
fd3b782
Use torch.tensor.
masaru-kimura-hacarus Sep 16, 2025
9336c22
Update torch.distributions.constraints._OneHot.check.
masaru-kimura-hacarus Sep 16, 2025
1d80bf8
Update torch.distributions.constraints._PositiveDefinite.check.
masaru-kimura-hacarus Sep 16, 2025
5c2e631
Use torch.tensor.
masaru-kimura-hacarus Sep 16, 2025
fc1b893
Update torch.distributions.constraints._PositiveSemiDefinite.check.
masaru-kimura-hacarus Sep 16, 2025
d552c5a
Update torch.distributions.constraints._CorrCholesky.check.
masaru-kimura-hacarus Sep 16, 2025
636303a
Update TransformerModel.GenerateSquareSubsequentMask.
masaru-kimura-hacarus Sep 16, 2025
6c8656a
Update TorchSharp.Modules.Tacotron2.forward.
masaru-kimura-hacarus Sep 16, 2025
bb92724
Update TorchSharp.Modules.Tacotron2.Attention.forward.
masaru-kimura-hacarus Sep 16, 2025
0858332
Update TorchSharp.Modules.NegativeBinomial.log_prob.
masaru-kimura-hacarus Sep 16, 2025
8880435
Simplify.
masaru-kimura-hacarus Sep 16, 2025
23f60dc
Add more operators.
masaru-kimura-hacarus Sep 16, 2025
f5f7c5c
Add more operators.
masaru-kimura-hacarus Sep 17, 2025
c2da2a7
Call PrintValue w/ explicitly declared TorchSharp.Scalar.
masaru-kimura-hacarus Sep 17, 2025
2c469da
Update TensorExtensionMethods.To*(this Tensor value).
masaru-kimura-hacarus Sep 17, 2025
f52c6c4
Introduce torch.Tensor.fill_ overloads.
masaru-kimura-hacarus Sep 17, 2025
cee753b
Update TorchSharp.Modules.Rprop.
masaru-kimura-hacarus Sep 17, 2025
b077e6a
Introduce torch.Tensor.index_put_ overloads.
masaru-kimura-hacarus Sep 17, 2025
ea78740
Introduce torch.Tensor.index_add{,_} overloads.
masaru-kimura-hacarus Sep 17, 2025
16f7e9b
Introduce torch.Tensor.index_fill{,_} overloads.
masaru-kimura-hacarus Sep 17, 2025
effda51
Introduce torch.Tensor.threshold{,_} overloads.
masaru-kimura-hacarus Sep 18, 2025
48281ed
Update torch.nn.functional.threshold.
masaru-kimura-hacarus Sep 18, 2025
e90139f
Introduce torch.Tensor.softplus overloads.
masaru-kimura-hacarus Sep 18, 2025
1287b2d
Add more torch.Tensor.celu{,_} overloads.
masaru-kimura-hacarus Sep 18, 2025
3da9849
Update torch.nn.functional.celu.
masaru-kimura-hacarus Sep 18, 2025
2a2b66e
Add more torch.Tensor.elu{,_} overloads.
masaru-kimura-hacarus Sep 18, 2025
41a0db8
Introduce torch.Tensor.hardtanh{,_} overloads.
masaru-kimura-hacarus Sep 18, 2025
b230428
Update torch.nn.functional.hardtanh.
masaru-kimura-hacarus Sep 18, 2025
789555b
Introduce torch.Tensor.leaky_relu{,_} overloads.
masaru-kimura-hacarus Sep 18, 2025
48e0ca8
Update torch.nn.functional.leaky_relu.
masaru-kimura-hacarus Sep 18, 2025
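
The thread running through these commits: many TorchSharp call sites pass numeric literals where a TorchSharp.Scalar is expected, and each such call implicitly allocates a native scalar handle that is only reclaimed when the finalizer eventually runs. The diffs below replace those implicit conversions with explicitly declared, explicitly disposed scalars. A minimal before/after sketch of the pattern (tensor values illustrative; ToScalar() and clamp are the TorchSharp APIs used throughout this PR):

using TorchSharp;
using static TorchSharp.torch;

internal static class ScalarDisposalSketch
{
    internal static Tensor Before(Tensor x)
    {
        // 0.0 and 1.0 are implicitly converted to TorchSharp.Scalar here;
        // nothing disposes those native handles until finalization.
        return x.clamp(0.0, 1.0);
    }

    internal static Tensor After(Tensor x)
    {
        // Explicit scalars are disposed deterministically at scope exit.
        using var zero_scalar = 0.0.ToScalar();
        using var one_scalar = 1.0.ToScalar();
        return x.clamp(zero_scalar, one_scalar);
    }
}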
2 changes: 1 addition & 1 deletion src/Examples.Utils/Examples.Utils.csproj
@@ -17,7 +17,7 @@

<ItemGroup>
<PackageReference Include="SharpZipLib" Version="1.4.0" />
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.7" />
<PackageReference Include="SixLabors.ImageSharp" Version="3.1.11" />
</ItemGroup>

<ItemGroup>
4 changes: 3 additions & 1 deletion src/Examples/AdversarialExampleGeneration.cs
@@ -105,7 +105,9 @@ internal static void Main(string[] args)
private static Tensor Attack(Tensor image, double ε, Tensor data_grad)
{
using (var sign = data_grad.sign()) {
-var perturbed = (image + ε * sign).clamp(0.0, 1.0);
+using var zero_scalar = 0.0.ToScalar();
+using var one_scalar = 1.0.ToScalar();
+var perturbed = (image + ε * sign).clamp(zero_scalar, one_scalar);
return perturbed;
}
}
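
Two commits above introduce TorchSharp.ScalarLeakDetector and TorchSharp.TensorLeakDetector, whose diffs are not visible in this view. Purely as an assumption (the names and API below are invented for illustration, not the PR's actual code), such a detector amounts to a live-handle counter that creation and disposal paths report to, so tests can assert the count returns to zero:

using System.Threading;

// Hypothetical sketch only; the real ScalarLeakDetector is not shown in
// this diff view. Creation and disposal hooks adjust a live-handle count.
public static class ScalarLeakDetectorSketch
{
    private static long _live;

    public static void OnCreated() => Interlocked.Increment(ref _live);
    public static void OnDisposed() => Interlocked.Decrement(ref _live);

    // Tests can assert LiveCount == 0 after a scope completes.
    public static long LiveCount => Interlocked.Read(ref _live);
}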
10 changes: 7 additions & 3 deletions src/Examples/SequenceToSequence.cs
@@ -234,10 +234,14 @@ public TransformerModel(long ntokens, long ninputs, long nheads, long nhidden, l

public Tensor GenerateSquareSubsequentMask(long size)
{
-var mask = (torch.ones(new long[] { size, size }) == 1).triu().transpose(0, 1);
+using var zero_scalar = 0.ToScalar();
+using var one_scalar = 1.ToScalar();
+using var float_negative_infinity_scalar = float.NegativeInfinity.ToScalar();
+using var float_zero_scalar = 0.0f.ToScalar(); // FIXME: Equivalent to zero_scalar?
+var mask = (torch.ones(new long[] { size, size }) == one_scalar).triu().transpose(0, 1);
return mask.to_type(ScalarType.Float32)
-.masked_fill(mask == 0, float.NegativeInfinity)
-.masked_fill(mask == 1, 0.0f).to(device);
+.masked_fill(mask == zero_scalar, float_negative_infinity_scalar)
+.masked_fill(mask == one_scalar, float_zero_scalar).to(device);
}

private void InitWeights()
4 changes: 4 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.h
@@ -1220,6 +1220,10 @@ EXPORT_API(Tensor) THSTensor_sqrt(const Tensor tensor);

EXPORT_API(void) THSTensor_sqrt_(const Tensor tensor);

+EXPORT_API(Tensor) THSTensor_square(const Tensor tensor);
+
+EXPORT_API(void) THSTensor_square_(const Tensor tensor);
+
EXPORT_API(Tensor) THSTensor_std(const Tensor tensor, const bool unbiased);

EXPORT_API(Tensor) THSTensor_std_along_dimensions(const Tensor tensor, const int64_t* dimensions, int length, bool unbiased, bool keepdim);
10 changes: 10 additions & 0 deletions src/Native/LibTorchSharp/THSTensorMath.cpp
@@ -910,6 +910,16 @@ void THSTensor_sqrt_(const Tensor tensor)
CATCH(tensor->sqrt_();)
}

+Tensor THSTensor_square(const Tensor tensor)
+{
+CATCH_TENSOR(tensor->square());
+}
+
+void THSTensor_square_(const Tensor tensor)
+{
+CATCH(tensor->square_();)
+}
+
Tensor THSTensor_sign(const Tensor tensor)
{
CATCH_TENSOR(tensor->sign());
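
With THSTensor_square exported from the native layer, managed code can call Tensor.square() where it previously wrote pow(2). Besides matching PyTorch's torch.square, this avoids the implicit Scalar allocated for the exponent. A short usage sketch (tensor shape illustrative):

using TorchSharp;
using static TorchSharp.torch;

var x = torch.rand(8);
var y1 = x.pow(2.0);  // implicitly allocates a Scalar for the exponent
var y2 = x.square();  // no Scalar is created at all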
42 changes: 25 additions & 17 deletions src/TorchAudio/Functional.cs
@@ -73,14 +73,15 @@ public static torch.Tensor spectrogram(torch.Tensor waveform, long pad, torch.Te
spec_f = spec_f.reshape(spec_shape);

if (normalized) {
-spec_f /= window.pow(2.0).sum().sqrt();
+spec_f /= window.square().sum().sqrt();
}

if (power.HasValue) {
if (power.Value == 1.0) {
spec_f = spec_f.abs();
} else {
-spec_f = spec_f.abs().pow(power.Value);
+using var power_scalar = power.Value.ToScalar();
+spec_f = spec_f.abs().pow(power_scalar); // FIXME: Call torch.Tensor.square if power.Value == 2.0?
}
}

@@ -112,7 +113,7 @@ public static torch.Tensor inverse_spectrogram(torch.Tensor spectrogram, long? l
using (var d = torch.NewDisposeScope()) {

if (normalized) {
-spectrogram = spectrogram * window.pow(2.0).sum().sqrt();
+spectrogram = spectrogram * window.square().sum().sqrt();
}

// pack batch
@@ -180,23 +181,24 @@ public static Tensor griffinlim(Tensor specgram, Tensor window, long n_fft, long
throw new ArgumentOutOfRangeException($"momentum must be in range [0, 1). Found: {momentum}");
}
momentum = momentum / (1 + momentum);
+var need_momentum = momentum > 0.0;
+using var momentum_scalar = (need_momentum) ? momentum.ToScalar() : null;

// pack batch
var shape = specgram.size();
specgram = specgram.reshape(new long[] { -1, shape[shape.Length - 2], shape[shape.Length - 1] });

-specgram = specgram.pow(1 / power);
+using var exponent_scalar = (1 / power).ToScalar();
+specgram = specgram.pow(exponent_scalar); // FIXME: Use inplace ops? Skip if power == 1?

// initialize the phase
-Tensor angles;
-if (rand_init) {
-angles = torch.rand(specgram.size(), dtype: _get_complex_dtype(specgram.dtype), device: specgram.device);
-} else {
-angles = torch.full(specgram.size(), 1, dtype: _get_complex_dtype(specgram.dtype), device: specgram.device);
-}
+var angles = (rand_init)
+? torch.rand(specgram.size(), dtype: _get_complex_dtype(specgram.dtype), device: specgram.device)
+: torch.ones(specgram.size(), dtype: _get_complex_dtype(specgram.dtype), device: specgram.device);

// And initialize the previous iterate to 0
var tprev = torch.tensor(0.0, dtype: specgram.dtype, device: specgram.device);
+using var eps_scalar = (1e-16).ToScalar();
for (int i = 0; i < n_iter; i++) {
// Invert with our current estimate of the phases
var inverse = torch.istft(
@@ -218,10 +220,10 @@

// Update our phase estimates
angles = rebuilt;
-if (momentum > 0.0) {
-angles = angles - tprev.mul_(momentum);
+if (need_momentum) {
+angles = angles - tprev.mul_(momentum_scalar!); // FIXME: Use inplace ops?
}
-angles = angles.div(angles.abs().add(1e-16));
+angles = angles.div(angles.abs().add(eps_scalar));

// Store the previous iterate
tprev = rebuilt;
@@ -528,18 +530,23 @@ internal static (torch.Tensor, int) _get_sinc_resample_kernel(int orig_freq, int
if (lowpass_filter_width <= 0) {
throw new ArgumentOutOfRangeException();
}
+using var min_scalar = (-lowpass_filter_width).ToScalar();
+using var max_scalar = lowpass_filter_width.ToScalar();

var kernels_list = new List<torch.Tensor>();
double base_freq = Math.Min(orig_freq, new_freq);
base_freq *= rolloff;

var width = (int)Math.Ceiling(((double)lowpass_filter_width) * orig_freq / base_freq);
var idx_dtype = dtype ?? torch.float64;
-var idx = torch.arange(-width, width + orig_freq, device: device, dtype: idx_dtype);
+using var start_scalar = (-width).ToScalar();
+using var stop_scalar = (width + orig_freq).ToScalar();
+var idx = torch.arange(start_scalar, stop_scalar, device: device, dtype: idx_dtype);

+using var zero_scalar = 0.ToScalar();
for (int i = 0; i < new_freq; i++) {
var t = (-i / new_freq + idx / orig_freq) * base_freq;
-t = t.clamp_(-lowpass_filter_width, lowpass_filter_width);
+t = t.clamp_(min_scalar, max_scalar);

torch.Tensor window;
if (resampling_method == ResamplingMethod.sinc_interpolation) {
@@ -554,13 +561,14 @@ internal static (torch.Tensor, int) _get_sinc_resample_kernel(int orig_freq, int
}
t *= Math.PI;
// Tensor.to(Tensor) of TorchSharp doesn't change dtype.
-var kernel = torch.where(t == 0, torch.tensor(1.0).to(t).type_as(t), torch.sin(t) / t);
+var kernel = torch.where(t == zero_scalar, torch.tensor(1.0).to(t).type_as(t), torch.sin(t) / t);
kernel.mul_(window);
kernels_list.Add(kernel);
}

var scale = ((double)base_freq) / orig_freq;
-var kernels = torch.stack(kernels_list.ToArray()).view(new_freq, 1, -1).mul_(scale);
+using var scale_scalar = scale.ToScalar();
+var kernels = torch.stack(kernels_list.ToArray()).view(new_freq, 1, -1).mul_(scale_scalar);
if (dtype == null) {
kernels = kernels.to(torch.float32);
}
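
Note how griffinlim hoists eps_scalar and momentum_scalar out of the n_iter loop: each Scalar is created once and reused on every iteration, where the old code's add(1e-16) allocated a fresh implicit Scalar per pass. A self-contained sketch of the hoisting pattern (loop bound and tensor values illustrative):

using TorchSharp;
using static TorchSharp.torch;

using var eps_scalar = (1e-16).ToScalar();  // created once, outside the loop
var angles = torch.rand(4, 4);
for (int i = 0; i < 32; i++) {
    // The same native scalar handle is reused on every iteration.
    angles = angles.div(angles.abs().add(eps_scalar));
}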
2 changes: 1 addition & 1 deletion src/TorchAudio/Modules/HuBERTPretrainModel.cs
@@ -91,7 +91,7 @@ public override (Tensor?, Tensor?, Tensor) forward(
if (this.feature_grad_mult != null && this.feature_grad_mult < 1.0) {
x = Wav2Vec2Model.GradMultiply.apply(x, this.feature_grad_mult.Value);
}
-var features_pen = x.@float().pow(2).mean();
+var features_pen = x.@float().square().mean();
if (lengths is not null) {
padding_mask = Wav2Vec2Model._get_padding_mask(x, lengths);
} else {
11 changes: 7 additions & 4 deletions src/TorchAudio/Modules/Tacotron2.cs
@@ -147,9 +147,11 @@ public override (Tensor, Tensor, Tensor, Tensor) forward(
mask = mask.expand(this.n_mels, mask.size(0), mask.size(1));
mask = mask.permute(1, 0, 2);

-mel_specgram = mel_specgram.masked_fill(mask, 0.0);
-mel_specgram_postnet = mel_specgram_postnet.masked_fill(mask, 0.0);
-gate_outputs = gate_outputs.masked_fill(mask[TensorIndex.Colon, 0, TensorIndex.Colon], 1e3);
+using var zero_scalar = 0.0.ToScalar();
+mel_specgram = mel_specgram.masked_fill(mask, zero_scalar);
+mel_specgram_postnet = mel_specgram_postnet.masked_fill(mask, zero_scalar);
+using var eps_scalar = 1e3.ToScalar();
+gate_outputs = gate_outputs.masked_fill(mask[TensorIndex.Colon, 0, TensorIndex.Colon], eps_scalar);
}

return (mel_specgram, mel_specgram_postnet, gate_outputs, alignments);
@@ -334,7 +336,8 @@ public override (Tensor, Tensor) forward(
{
var alignment = this._get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat);

-alignment = alignment.masked_fill(mask, this.score_mask_value);
+using var score_mask_value_scalar = this.score_mask_value.ToScalar();
+alignment = alignment.masked_fill(mask, score_mask_value_scalar);

var attention_weights = F.softmax(alignment, dim: 1);
var attention_context = torch.bmm(attention_weights.unsqueeze(1), memory);
7 changes: 4 additions & 3 deletions src/TorchAudio/Transforms/InverseMelScale.cs
@@ -96,18 +96,19 @@ public override Tensor forward(Tensor melspec)
learningRate: 0.1, momentum: 0.9);

var loss = float.PositiveInfinity;
+using var zero_scalar = 0.ToScalar();
for (long i = 0; i < this.max_iter; i++) {
using var d2 = torch.NewDisposeScope();

optim.zero_grad();
var diff = melspec - specgram.matmul(this.fb);
-var new_loss = diff.pow(2).sum(dim: -1).mean();
+var new_loss = diff.square().sum(dim: -1).mean();
// take sum over mel-frequency then average over other dimensions
// so that loss threshold is applied per unit timeframe
new_loss.backward();
optim.step();
using (torch.no_grad())
-specgram.set_(specgram.clamp(min: 0));
+specgram.set_(specgram.clamp(min: zero_scalar));

var new_loss_value = new_loss.item<float>();
if (new_loss_value < this.tolerance_loss || Math.Abs(loss - new_loss_value) < this.tolerance_change) {
@@ -117,7 +118,7 @@
}

specgram.requires_grad_(false);
-var specgram_tensor = specgram.clamp(min: 0).transpose(-1, -2);
+var specgram_tensor = specgram.clamp(min: zero_scalar).transpose(-1, -2);

// unpack batch
shape[shape.Length - 2] = freq;
2 changes: 1 addition & 1 deletion src/TorchSharp/Distributions/Beta.cs
@@ -28,7 +28,7 @@ public override Tensor variance {
get {
using var _ = NewDisposeScope();
var total = concentration0 + concentration1;
-return (concentration1 * concentration0 / (total.pow(2) * (total + 1))).MoveToOuterDisposeScope();
+return (concentration1 * concentration0 / (total.square() * (total + 1))).MoveToOuterDisposeScope();
}
}

2 changes: 1 addition & 1 deletion src/TorchSharp/Distributions/Cauchy.cs
@@ -78,7 +78,7 @@ public override Tensor rsample(params long[] sample_shape)
/// </summary>
/// <param name="value"></param>
public override Tensor log_prob(Tensor value) =>
-WrappedTensorDisposeScope(() => -Math.Log(Math.PI) - scale.log() - (((value - loc) / scale).pow(2)).log1p());
+WrappedTensorDisposeScope(() => -Math.Log(Math.PI) - scale.log() - (((value - loc) / scale).square()).log1p());

/// <summary>
/// Returns entropy of distribution, batched over batch_shape.
16 changes: 10 additions & 6 deletions src/TorchSharp/Distributions/Constraints.cs
@@ -137,8 +137,10 @@ public _OneHot() : base(true, 1) { }

public override Tensor check(Tensor value)
{
-var is_boolean = (value == 0) | (value == 1);
-var is_normalized = value.sum(-1).eq(1);
+using var zero_scalar = 0.ToScalar();
+using var one_scalar = 1.ToScalar();
+var is_boolean = (value == zero_scalar) | (value == one_scalar);
+var is_normalized = value.sum(-1).eq(one_scalar);
return is_boolean.all(-1) & is_normalized;
}
}
@@ -433,9 +435,9 @@ public _CorrCholesky() : base(false, 2) { }

public override Tensor check(Tensor value)
{
-var tol = torch.finfo(value.dtype).eps * value.size(-1) * 10; // 10 is an adjustable fudge factor
+using var tol_scalar = (torch.finfo(value.dtype).eps * value.size(-1) * 10).ToScalar(); // 10 is an adjustable fudge factor
var row_norm = torch.linalg.norm(value.detach(), dims: new[] { -1L });
-var unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim: -1);
+var unit_row_norm = (row_norm - 1.0).abs().le(tol_scalar).all(dim: -1);
return lc.check(value) & unit_row_norm;
}

@@ -489,7 +491,8 @@ public override Tensor check(Tensor value)
var sym_check = base.check(value);
if (!sym_check.all().item<bool>())
return sym_check;
-return torch.linalg.eigvalsh(value).ge(0).all(-1);
+using var zero_scalar = 0.ToScalar();
+return torch.linalg.eigvalsh(value).ge(zero_scalar).all(-1);
}
}

@@ -503,7 +506,8 @@ public override Tensor check(Tensor value)
var sym_check = base.check(value);
if (!sym_check.all().item<bool>())
return sym_check;
-return torch.linalg.cholesky_ex(value).info.eq(0);
+using var zero_scalar = 0.ToScalar();
+return torch.linalg.cholesky_ex(value).info.eq(zero_scalar);
}
}

8 changes: 5 additions & 3 deletions src/TorchSharp/Distributions/Dirichlet.cs
@@ -25,9 +25,11 @@ public override Tensor mode
{
get {
using var _ = NewDisposeScope();
-var concentrationm1 = (concentration - 1).clamp(min: 0.0);
+using var zero_scalar = 0.0.ToScalar();
+var concentrationm1 = (concentration - 1).clamp(min: zero_scalar);
var mode = concentrationm1 / concentrationm1.sum(-1, true);
-var mask = (concentration < 1).all(dim: -1);
+using var one_scalar = 1.ToScalar();
+var mask = (concentration < one_scalar).all(dim: -1);
mode[mask] = torch.nn.functional.one_hot(mode[mask].argmax(dim: -1), concentrationm1.shape[concentrationm1.ndim-1]).to(mode);
return mode.MoveToOuterDisposeScope();
}
@@ -40,7 +42,7 @@
get {
using var _ = NewDisposeScope();
var con0 = concentration.sum(-1, true);
-return (concentration * (con0 - concentration) / (con0.pow(2) * (con0 + 1))).MoveToOuterDisposeScope();
+return (concentration * (con0 - concentration) / (con0.square() * (con0 + 1))).MoveToOuterDisposeScope();
}
}

10 changes: 8 additions & 2 deletions src/TorchSharp/Distributions/Distribution.cs
@@ -183,10 +183,16 @@ protected Tensor ProbsToLogits(Tensor probs, bool isBinary = false)
protected Tensor ClampProbs(Tensor probs)
{
var eps = torch.finfo(probs.dtype).eps;
-return probs.clamp(eps, 1 - eps);
+using var eps_scalar = eps.ToScalar();
+using var eps_bar_scalar = (1 - eps).ToScalar();
+return probs.clamp(eps_scalar, eps_bar_scalar);
}

-protected Tensor ClampByZero(Tensor x) => (x.clamp_min(0) + x - x.clamp_max(0)) / 2;
+protected Tensor ClampByZero(Tensor x)
+{
+using var zero_scalar = 0.ToScalar();
+return (x.clamp_min(zero_scalar) + x - x.clamp_max(zero_scalar)) / 2;
+}

protected torch.Generator generator;
bool disposedValue;
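
ClampByZero changes from an expression-bodied member to a block body because C# does not allow a using declaration inside an => expression; the block form is what lets the zero scalar be scoped and disposed. The same pattern restated as a standalone sketch (the wrapper class is added here for illustration):

using TorchSharp;
using static TorchSharp.torch;

internal static class ClampByZeroSketch
{
    internal static Tensor ClampByZero(Tensor x)
    {
        // The using declaration requires a statement body; it disposes
        // zero_scalar when the method returns.
        using var zero_scalar = 0.ToScalar();
        return (x.clamp_min(zero_scalar) + x - x.clamp_max(zero_scalar)) / 2;
    }
}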
4 changes: 3 additions & 1 deletion src/TorchSharp/Distributions/ExpRelaxedCategorical.cs
@@ -134,7 +134,9 @@ public override Tensor log_prob(Tensor value)
var logitsValue = broadcast_tensors(_logits, value);
var logits = logitsValue[0];
value = logitsValue[1];
-var log_scale = (torch.full_like(_temperature, K).lgamma() - _temperature.log().mul(-(K - 1)));
+using var K_scalar = K.ToScalar();
+using var negative_Ksub1_scalar = (-(K - 1)).ToScalar();
+var log_scale = torch.full_like(_temperature, K_scalar).lgamma() - _temperature.log().mul(negative_Ksub1_scalar); // FIXME: Use inplace ops?
var score = logits - value.mul(_temperature);
score = (score - score.logsumexp(dim: -1, keepdim: true)).sum(-1);
return (score + log_scale).MoveToOuterDisposeScope();
2 changes: 1 addition & 1 deletion src/TorchSharp/Distributions/Exponential.cs
@@ -24,7 +24,7 @@ public class Exponential : torch.distributions.ExponentialFamily
/// <summary>
/// The variance of the distribution
/// </summary>
-public override Tensor variance => rate.pow(2);
+public override Tensor variance => rate.square();

/// <summary>
/// The standard deviation of the distribution