
Module.Load throws Mismatched state_dict sizes exception on BatchNorm1d #510

@FusionCarcass

Description


When loading a model from a .dat file exported from Python, the Module.Load method throws the exception below. I printed out all of the registered parameters, and the only entries that didn't show up are the BatchNorm1d running statistics: running_mean, running_var, and num_batches_tracked.

I tried to work around the problem by registering those tensors with the register_parameter function. That eliminates the exception below, but introduces a different issue: after the extra parameters are registered, the bias is no longer loaded correctly and remains torch.zeros(N).

System.ArgumentException
  HResult=0x80070057
  Message=Mismatched state_dict sizes: expected 200, but found 300 entries.
  Source=TorchSharp
  StackTrace:
   at TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict)
   at TorchSharp.torch.nn.Module.load(String location, Boolean strict)
   at OpenHips.Scanners.TorchScanner.LoadFromFile(String fileName) in C:\Users\helpdesk\Desktop\Workspace\repos\open-hips\open-hips-cortex\Scanners\TorchScanner.cs:line 15
   at OpenHips.Program.HandleScan(FileInfo target, FileInfo model) in C:\Users\helpdesk\Desktop\Workspace\repos\open-hips\open-hips-scanner\Program.cs:line 33
   at System.CommandLine.Handler.<>c__DisplayClass2_0`2.<SetHandler>b__0(InvocationContext context)
   at System.CommandLine.Invocation.AnonymousCommandHandler.<>c__DisplayClass2_0.<.ctor>g__Handle|0(InvocationContext context)
   at System.CommandLine.Invocation.AnonymousCommandHandler.<InvokeAsync>d__3.MoveNext()

The load method should probably take registered buffers into account, given that running_mean and running_var are not registered as parameters.
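For illustration, here is a plain-Python sketch (not TorchSharp's actual implementation) of why a loader that counts only registered parameters undercounts the entries in a PyTorch-exported state_dict: each BatchNorm1d layer is serialized with two trainable parameters plus three buffers.

```python
# Hypothetical sketch -- not TorchSharp's actual code.
# A BatchNorm1d layer exported from PyTorch carries five state_dict
# entries: two trainable parameters and three buffers.
PARAMS = ["weight", "bias"]
BUFFERS = ["running_mean", "running_var", "num_batches_tracked"]

def entries_in_file(num_batchnorm_layers):
    # What the exporter wrote: parameters AND buffers.
    return num_batchnorm_layers * (len(PARAMS) + len(BUFFERS))

def entries_expected_by_loader(num_batchnorm_layers):
    # A loader that walks only registered parameters misses the buffers.
    return num_batchnorm_layers * len(PARAMS)

found = entries_in_file(1)
expected = entries_expected_by_loader(1)
if expected != found:
    # Mirrors the shape of the ArgumentException message above.
    print(f"Mismatched state_dict sizes: expected {expected}, "
          f"but found {found} entries.")
```

Each BatchNorm1d in the model therefore contributes three file entries the parameters-only loader does not expect, which is consistent with the file containing more entries than the module tree predicts.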

Here is the workaround I tried:

internal class BasicConv1d : Module {
    private readonly Sequential stack;

    public BasicConv1d(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) : base(String.Empty) {
        BatchNorm1d temp = BatchNorm1d(out_channels);
        //temp.reset_running_stats();
        temp.running_mean = new Parameter(torch.zeros(out_channels, requiresGrad: false), requires_grad: false);
        temp.running_var = new Parameter(torch.zeros(out_channels, requiresGrad: false), requires_grad: false);
        temp.register_parameter("num_batches_tracked", new Parameter(temp.state_dict()["num_batches_tracked"], requires_grad: false));
        temp.register_parameter("running_mean", temp.running_mean);
        temp.register_parameter("running_var", temp.running_var);
        temp.bias = temp.get_parameter("bias"); // Without this line, temp.bias is all zeros after load. With this line, temp.bias is equal to temp.running_var.
        this.stack = Sequential(
            Conv1d(in_channels, out_channels, kernel_size, stride: stride, padding: padding, bias: false),
            temp,
            ReLU(inPlace: true)
        );

        this.RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor t) {
        return this.stack.forward(t);
    }
}
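The underlying issue is that PyTorch's nn.Module keeps parameters and buffers in separate collections but serializes both into the state_dict, so registering the running statistics as parameters changes their identity and ordering on the TorchSharp side. The sketch below (pure Python, no torch dependency, names hypothetical) shows the split a loader would need to mirror:

```python
# Pure-Python sketch of the parameter/buffer split that PyTorch's
# nn.Module uses when building a state_dict. Not TorchSharp code.
class MiniModule:
    def __init__(self):
        self._parameters = {}  # learned tensors (weight, bias)
        self._buffers = {}     # persistent non-learned state (running stats)

    def register_parameter(self, name, value):
        self._parameters[name] = value

    def register_buffer(self, name, value):
        self._buffers[name] = value

    def state_dict(self):
        # Both collections are serialized, which is why the exported
        # file has more entries than a parameters-only loader expects.
        d = dict(self._parameters)
        d.update(self._buffers)
        return d

bn = MiniModule()
bn.register_parameter("weight", [1.0] * 4)
bn.register_parameter("bias", [0.0] * 4)
bn.register_buffer("running_mean", [0.0] * 4)
bn.register_buffer("running_var", [1.0] * 4)
bn.register_buffer("num_batches_tracked", 0)
print(len(bn.state_dict()))  # 5
```

A buffer-aware load that walks parameters and buffers by name, rather than by position, would also avoid the bias/running_var mix-up seen in the workaround above.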
