Description
When loading a model from a .dat file exported from Python, the Module.load method throws the exception below. I printed out all of the registered parameters, and the only entries from the saved state_dict that do not show up are the BatchNorm1d running statistics: running_mean, running_var, and num_batches_tracked.
I tried to work around this problem by registering those tensors with the register_parameter function. That eliminates the exception below, but I then run into a different issue: the bias is no longer loaded correctly once the extra parameters are registered. The bias parameter is still set to torch.zeros(N) after load.
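A rough sketch of the check that surfaces the missing names (model is just a placeholder for the module instance being inspected; the usual System and System.Collections.Generic usings are assumed):

// List the state_dict entries that have no matching registered parameter.
var parameterNames = new HashSet<string>();
foreach (var (name, _) in model.named_parameters())
    parameterNames.Add(name);
foreach (var kv in model.state_dict())
    if (!parameterNames.Contains(kv.Key))
        Console.WriteLine($"not a registered parameter: {kv.Key}");
// Prints the BatchNorm1d entries: running_mean, running_var, num_batches_tracked.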
System.ArgumentException
HResult=0x80070057
Message=Mismatched state_dict sizes: expected 200, but found 300 entries.
Source=TorchSharp
StackTrace:
at TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict)
at TorchSharp.torch.nn.Module.load(String location, Boolean strict)
at OpenHips.Scanners.TorchScanner.LoadFromFile(String fileName) in C:\Users\helpdesk\Desktop\Workspace\repos\open-hips\open-hips-cortex\Scanners\TorchScanner.cs:line 15
at OpenHips.Program.HandleScan(FileInfo target, FileInfo model) in C:\Users\helpdesk\Desktop\Workspace\repos\open-hips\open-hips-scanner\Program.cs:line 33
at System.CommandLine.Handler.<>c__DisplayClass2_0`2.<SetHandler>b__0(InvocationContext context)
at System.CommandLine.Invocation.AnonymousCommandHandler.<>c__DisplayClass2_0.<.ctor>g__Handle|0(InvocationContext context)
at System.CommandLine.Invocation.AnonymousCommandHandler.<InvokeAsync>d__3.MoveNext()
The load method should probably take the registered buffers into consideration if running_mean and running_var are not going to be exposed as registered parameters; as it stands, those entries are present in the exported state_dict but are never counted on the TorchSharp side.
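To illustrate, this is roughly the registration I would expect the BatchNorm1d module (or load itself) to rely on. It is only a sketch, it assumes register_buffer in TorchSharp mirrors PyTorch's semantics, and it reuses the temp / out_channels names from the workaround below:

// Hypothetical buffer-based registration for the running statistics,
// instead of forcing them into register_parameter as in the workaround below.
temp.register_buffer("running_mean", torch.zeros(out_channels));
temp.register_buffer("running_var", torch.ones(out_channels));
temp.register_buffer("num_batches_tracked", torch.tensor(0L));
// load(...) would then match these against the running_mean / running_var /
// num_batches_tracked entries of the exported state_dict.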
Here are the fixes I tried.
// Usings needed for the snippet to compile on its own:
using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

internal class BasicConv1d : Module {
    private readonly Sequential stack;

    public BasicConv1d(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) : base(String.Empty) {
        BatchNorm1d temp = BatchNorm1d(out_channels);
        //temp.reset_running_stats();

        // Expose the running statistics as parameters so that load() finds them.
        temp.running_mean = new Parameter(torch.zeros(out_channels, requiresGrad: false), requires_grad: false);
        temp.running_var = new Parameter(torch.zeros(out_channels, requiresGrad: false), requires_grad: false);
        temp.register_parameter("num_batches_tracked", new Parameter(temp.state_dict()["num_batches_tracked"], requires_grad: false));
        temp.register_parameter("running_mean", temp.running_mean);
        temp.register_parameter("running_var", temp.running_var);

        // Without this line, temp.bias is all zeros after load.
        // With this line, temp.bias ends up equal to temp.running_var.
        temp.bias = temp.get_parameter("bias");

        this.stack = Sequential(
            Conv1d(in_channels, out_channels, kernel_size, stride: stride, padding: padding, bias: false),
            temp,
            ReLU(inPlace: true)
        );
        this.RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor t) {
        return this.stack.forward(t);
    }
}
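For completeness, a minimal repro sketch of the load call that triggers the exception (the file name and sizes are illustrative, not the real model):

// Illustrative repro: load the weights exported from Python into the module above.
var block = new BasicConv1d(in_channels: 1, out_channels: 100, kernel_size: 3);
block.load("exported_model.dat"); // throws the ArgumentException above without the workaround;
                                  // with the workaround it loads, but the BatchNorm bias stays zero.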