Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the

__Breaking Changes__:

- `torchvision.dataset.MNIST` will try more mirrors.
- The thrown exception might be changed when it fails to download `MNIST`, `FashionMNIST` or `KMNIST`.

__API Changes__:

__Bug Fixes__:
Expand Down
17 changes: 16 additions & 1 deletion src/TorchVision/dsets/CIFAR.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,26 @@ protected void DownloadFile(string file, string target, string baseUrl)
lock (_httpClient) {
using var s = _httpClient.GetStreamAsync(netPath).Result;
using var fs = new FileStream(filePath, FileMode.CreateNew);
s.CopyToAsync(fs).Wait();
s.CopyTo(fs);
}
}
}

protected void DownloadFile(string file, string target, IEnumerable<string> baseUrls)
{
var exceptions = new List<Exception>();
foreach (var baseUrl in baseUrls) {
try {
DownloadFile(file, target, baseUrl);
return;
} catch (Exception e) {
exceptions.Add(e);
continue;
}
}
throw new AggregateException($"Error downloading {file}", exceptions);
}

protected static string JoinPaths(string directory, string file)
{
#if NETSTANDARD2_0_OR_GREATER
Expand Down
33 changes: 23 additions & 10 deletions src/TorchVision/dsets/MNIST.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ namespace Modules
/// </summary>
internal class MNIST : DatasetHelper
{
private static string[] Mirrors => new[] {
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/"
};

/// <summary>
/// Constructor
/// </summary>
Expand All @@ -84,13 +89,13 @@ internal class MNIST : DatasetHelper
/// <param name="download"></param>
/// <param name="transform">Transform for input MNIST image</param>
public MNIST(string root, bool train, bool download = false, torchvision.ITransform transform = null) :
this(root, "mnist", train ? "train" : "t10k", "http://yann.lecun.com/exdb/mnist/", download, transform)
this(root, "mnist", train ? "train" : "t10k", Mirrors, download, transform)
{
}

protected MNIST(string root, string datasetName, string prefix, string baseUrl, bool download, torchvision.ITransform transform)
protected MNIST(string root, string datasetName, string prefix, IEnumerable<string> baseUrls, bool download, torchvision.ITransform transform)
{
if (download) Download(root, baseUrl, datasetName);
if (download) Download(root, baseUrls, datasetName);

this.transform = transform;

Expand Down Expand Up @@ -156,7 +161,7 @@ protected MNIST(string root, string datasetName, string prefix, string baseUrl,
}
}

private void Download(string root, string baseUrl, string dataset)
private void Download(string root, IEnumerable<string> baseUrls, string dataset)
{
#if NETSTANDARD2_0_OR_GREATER
var datasetPath = NSPath.Join(root, dataset);
Expand All @@ -171,10 +176,10 @@ private void Download(string root, string baseUrl, string dataset)
Directory.CreateDirectory(sourceDir);
}

DownloadFile("train-images-idx3-ubyte.gz", sourceDir, baseUrl);
DownloadFile("train-labels-idx1-ubyte.gz", sourceDir, baseUrl);
DownloadFile("t10k-images-idx3-ubyte.gz", sourceDir, baseUrl);
DownloadFile("t10k-labels-idx1-ubyte.gz", sourceDir, baseUrl);
DownloadFile("train-images-idx3-ubyte.gz", sourceDir, baseUrls);
DownloadFile("train-labels-idx1-ubyte.gz", sourceDir, baseUrls);
DownloadFile("t10k-images-idx3-ubyte.gz", sourceDir, baseUrls);
DownloadFile("t10k-labels-idx1-ubyte.gz", sourceDir, baseUrls);

if (!Directory.Exists(targetDir)) {
Directory.CreateDirectory(targetDir);
Expand Down Expand Up @@ -229,6 +234,10 @@ public override Dictionary<string, Tensor> GetTensor(long index)

internal class FashionMNIST : MNIST
{
private static string[] Mirrors => new[] {
"https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/"
};

/// <summary>
/// Constructor
/// </summary>
Expand All @@ -237,13 +246,17 @@ internal class FashionMNIST : MNIST
/// <param name="download"></param>
/// <param name="transform">Transform for input MNIST image</param>
public FashionMNIST(string root, bool train, bool download = false, torchvision.ITransform transform = null) :
base(root, "fashion-mnist", train ? "train" : "t10k", "https://github.com/zalandoresearch/fashion-mnist/raw/master/data/fashion/", download, transform)
base(root, "fashion-mnist", train ? "train" : "t10k", Mirrors, download, transform)
{
}
}

internal class KMNIST : MNIST
{
private static string[] Mirrors => new[] {
"http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
};

/// <summary>
/// Constructor
/// </summary>
Expand All @@ -252,7 +265,7 @@ internal class KMNIST : MNIST
/// <param name="download"></param>
/// <param name="transform">Transform for input MNIST image</param>
public KMNIST(string root, bool train, bool download = false, torchvision.ITransform transform = null) :
base(root, "kmnist", train ? "train" : "t10k", "http://codh.rois.ac.jp/kmnist/dataset/kmnist/", download, transform)
base(root, "kmnist", train ? "train" : "t10k", Mirrors, download, transform)
{
}
}
Expand Down
6 changes: 3 additions & 3 deletions test/TorchSharpTest/LinearAlgebra.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,19 +404,19 @@ public void SolveTriangularTest()
var A = randn(3, 3).triu_();
var b = randn(3, 4);
var x = linalg.solve_triangular(A, b, upper: true);
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06));
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-05));
}
{
var A = randn(2, 3, 3).tril_();
var b = randn(2, 3, 4);
var x = linalg.solve_triangular(A, b, upper: false);
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06));
Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-05));
}
{
var A = randn(2, 4, 4).tril_();
var b = randn(2, 3, 4);
var x = linalg.solve_triangular(A, b, upper: false, left: false);
Assert.True(x.matmul(A).allclose(b, rtol: 1e-03, atol: 1e-06));
Assert.True(x.matmul(A).allclose(b, rtol: 1e-03, atol: 1e-05));
}
}

Expand Down