-
Notifications
You must be signed in to change notification settings - Fork 67
feat(models): mapper refactor #574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
JPXKQX
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just left some minor comments/questions:
- Some of these comments relate to what should be included in the
_init_graph_mode()function, and whether it could do more than just set the graph mode. - Another point is whether it makes sense to pass the
edge_indexandedge_attrin the init of the mappers (insteadsub_graphandsub_graph_edge_attributes), and whether the selection of the edge attributes should be moved from the mapper to the model.
1591075 to
5c4cdbc
Compare
matschreiner
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Simon!
Thanks for implementing this, it is really cool!
There still seem to be two parallel control flows: one where graphs are provided externally (dynamic case) and one where they’re provided by the GraphProvider (static case). The purpose of this design is to abstract graph handling into a single place, but it’s still split between two places.
I think either the DynamicGraphProvider should be hooked into the “graph stream” somehow, or the static graph should be provided at the same point in the pipeline as the dynamic graphs are now, so both can be passed consistently into the mappers’ forward methods.
matschreiner
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM :)
pass edges and edge attributes to mapper and processor
for more information, see https://pre-commit.ci
|
UPDATE: One integration test is failing: test_restart_from_existing_checkpoint. |
|
|
||
| class SparseProjector(torch.nn.Module): | ||
| """Constructs and applies a sparse projection matrix for mapping features between grids. | ||
| """Applies sparse projection matrix to input tensors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the logic gets moved to the ProjectionGraphProvider, this class could be removed as it essentially just performs a matmul.
Description
As described in #552 refactor mappers to enable both flexible graph behaviour, static, dynamic etc.