
Document Python -> Rust Model Translation Best Practices #549

@ralfbiedert

One follow-up to #543, which also touches on #545 and an answer you gave here (#174).

While you provide instructions on how to re-create the weights from a specific Python version, would it be possible to provide a guide on how best to replicate the corresponding architectures?

Background

When I followed the Python instructions, they worked fine for producing the .ot file, but then (taking the yolo example) you also need to "magically know", among other things:

  • correct layers and their ordering
  • getting the right nn::func_t behavior
  • the VarStore / Path labels for each weight set
  • postprocessing and label information (e.g., coco_classes.rs)
  • ...

For simple networks this seems mildly guessable, but when I tried to re-create yolo3 I already ran into these issues, let alone when I was looking into yolo5, yolo6 or yolo7.

I tried calling Python's print() on the models and using the output as guidance for a coarse outline, and I also inspected the model blobs in a Python debugger. For the more intricate parts, though, I hit a wall pretty quickly: I had to step through the actual model source in minute detail, and I still could not be sure I had ended up with the right thing.
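To illustrate the kind of coarse outline meant here, below is a minimal sketch that groups dotted parameter names by their prefix, which roughly corresponds to the nesting one would have to reproduce with VarStore paths. The names in `example_names` are hypothetical and only imitate the style of PyTorch `state_dict()` keys; in practice they would be dumped from the actual Python model.

```python
def group_by_prefix(names):
    """Group dotted parameter names by everything before the last dot.

    The result maps each module-like prefix to the list of leaf
    parameter names found under it, in first-seen order.
    """
    groups = {}
    for name in names:
        prefix, _, leaf = name.rpartition(".")
        groups.setdefault(prefix, []).append(leaf)
    return groups


# Hypothetical names, written in the style of state_dict keys.
example_names = [
    "layers.0.conv.weight",
    "layers.0.bn.weight",
    "layers.0.bn.bias",
    "layers.1.conv.weight",
    "layers.1.conv.bias",
]

outline = group_by_prefix(example_names)
for prefix, leaves in outline.items():
    print(prefix, "->", ", ".join(leaves))
```

This only recovers which named weight groups exist. It says nothing about layer ordering in the forward pass or about custom functions, which is exactly the part that remains unclear.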

Question

So, tl;dr, would you mind sharing or documenting your "best practice" for converting "arbitrary" models to tch? In particular:

  1. Before you even start, are there "red flags" that tell you right away that converting a model won't be worth it?
    1.1 When is it feasible to recreate a model with existing weights?
    1.2 If recreation doesn't work, any thoughts on JIT?
  2. Were there any technical / conversion reasons you picked YOLO_v3_tutorial_from_scratch and not, say, a model from torchhub?
  3. What has been your general workflow to convert something like yolo3-from-scratch?
    3.1 How did you determine the right layers and params (e.g., using a debugger, reading the Python source line by line, ...)?
    3.2 Same for custom functions? I assume you have to get these from the source in any case.
    3.3 Where do you get the VarStore labels from?
    3.4 Do you have any debugging / QA tips (e.g., to actually verify that all parameters / weights are correct w.r.t. Python)?
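
For 3.4, the sort of check I have in mind could look like the following minimal sketch. It assumes that a name-to-shape table has been dumped from each side (for example from `state_dict()` in Python and from the VarStore in Rust) and compares the two. The function name `diff_params` and all names and shapes in the example tables are hypothetical.

```python
def diff_params(reference, candidate):
    """Compare two name -> shape tables.

    Returns (missing, unexpected, mismatched):
      missing    - names only present in the reference table
      unexpected - names only present in the candidate table
      mismatched - names present in both, but with different shapes,
                   mapped to (reference_shape, candidate_shape)
    """
    ref_names = set(reference)
    cand_names = set(candidate)
    missing = sorted(ref_names - cand_names)
    unexpected = sorted(cand_names - ref_names)
    mismatched = {
        name: (tuple(reference[name]), tuple(candidate[name]))
        for name in sorted(ref_names & cand_names)
        if tuple(reference[name]) != tuple(candidate[name])
    }
    return missing, unexpected, mismatched


# Hypothetical table dumped from the Python model.
python_side = {
    "layers.0.conv.weight": (32, 3, 3, 3),
    "layers.0.bn.weight": (32,),
    "layers.0.bn.bias": (32,),
}

# Hypothetical table dumped from the Rust re-implementation.
rust_side = {
    "layers.0.conv.weight": (32, 3, 3, 3),
    "layers.0.bn.weight": (64,),
    "layers.0.batch_norm.bias": (32,),
}

missing, unexpected, mismatched = diff_params(python_side, rust_side)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

A clean diff would only show that names and shapes line up. It would not prove that the values or the forward pass match, so comparing outputs on a fixed input would presumably still be needed.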

I don't think this has to be overly long, but a few lines might help people make sure they're on the right track and following "best practices".
