Skip to content

Conversation

@Qazalbash
Copy link
Contributor

This PR contains the resolution of mypy errors passed by #2032, in numpyro.distributions.transforms module.

There are two cases in particular which I am unable to resolve. You can see them by running the mypy.

log_abs_det_jacobian of many transforms have unused parameters. I have typed them as union of UnusedParam and some appropriate numpy/jax type.

Many cases were unresolvable, like, __eq__ method expects bool as return type, but & operation between arrays return array of type bool, which conflicts with the return type, therefore I have added the tag to ignore them. You will find similar tags in the file.

@juanitorduz
Copy link
Collaborator

This looks great, thanks! There are some minor errors

numpyro/distributions/transforms.py:82: error: Missing positional argument "x" in call to "__call__" of "TransformT"  [call-arg]
Installing missing stub packages:
numpyro/distributions/transforms.py:84: error: Name "inv" already defined on line 80  [no-redef]
numpyro/distributions/transforms.py:85: error: Incompatible types in assignment (expression has type "ReferenceType[None]", variable has type "TransformT | None")  [assignment]
numpyro/distributions/transforms.py:86: error: Incompatible return value type (got "Array | Any | None", expected "TransformT")  [return-value]
numpyro/distributions/transforms.py:1550: error: Item "ndarray[tuple[Any, ...], dtype[Any]]" of "ndarray[tuple[Any, ...], dtype[Any]] | Array" has no attribute "at"  [union-attr]

Do you need some help with these :) ?

@Qazalbash
Copy link
Contributor Author

Qazalbash commented Aug 29, 2025

numpyro/distributions/transforms.py:82: error: Missing positional argument "x" in call to "__call__" of "TransformT"  [call-arg]
Installing missing stub packages:
numpyro/distributions/transforms.py:84: error: Name "inv" already defined on line 80  [no-redef]
numpyro/distributions/transforms.py:85: error: Incompatible types in assignment (expression has type "ReferenceType[None]", variable has type "TransformT | None")  [assignment]
numpyro/distributions/transforms.py:86: error: Incompatible return value type (got "Array | Any | None", expected "TransformT")  [return-value]

These errors are from numpyro/distributions/transforms.py#L77-L86, and I am not able to understand the significance of different conditions and the weak reference. I think you can look into this matter.

I will take this one,

numpyro/distributions/transforms.py:1550: error: Item "ndarray[tuple[Any, ...], dtype[Any]]" of "ndarray[tuple[Any, ...], dtype[Any]] | Array" has no attribute "at"  [union-attr]

Thank you for offering help ❤️.

@juanitorduz
Copy link
Collaborator

juanitorduz commented Aug 29, 2025

ok! Sounds like a plan! I will try to look at it in the next days :)

@juanitorduz
Copy link
Collaborator

Hey @Qazalbash I gave it a try as in d57d9a6 . MyPy is happy now, maybe you can try it ? The only key change was self.inv() -> self.inv which I think makes more sense, let's see if the tests complain ;)

juanitorduz referenced this pull request Aug 30, 2025
@Qazalbash
Copy link
Contributor Author

@juanitorduz Thanks for the changes, mypy is happy now.

Do I need to remove the plugin from pyproject.toml?

@juanitorduz
Copy link
Collaborator

It's depreciated so it's safe to remove

@Qazalbash Qazalbash requested a review from fehiepsi September 4, 2025 12:04
@juanitorduz
Copy link
Collaborator

juanitorduz commented Sep 4, 2025

ok! I think the tests are failing because a new NNX release and changes in nnx.merge, see https://github.com/google/flax/releases/tag/v0.11.2

The other tests FAILED test/test_distributions.py::test_entropy_samples I am not sure about.

@juanitorduz
Copy link
Collaborator

Here is a patch for the first errors #2067

@Qazalbash
Copy link
Contributor Author

Here's another patch #2069 😸

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Qazalbash, I think we can get around issues by using NumLike just at some specific places. It's fine to use StrictArray at those arrays with dim >= 1. NonScalarArray is a good name for it I guess.

@juanitorduz
Copy link
Collaborator

@Qazalbash do you need any support here to bring this one to the finish line :) ?

@Qazalbash
Copy link
Contributor Author

@Qazalbash do you need any support here to bring this one to the finish line :) ?

@juanitorduz thanks, but not quite right now. Hopefully, I will sit down tonight and tomorrow to complete it.

@Qazalbash
Copy link
Contributor Author

@juanitorduz I thought it would be easy to fix, but to my surprise, these changes are generating more errors than before. Would you like to take over this issue?

@juanitorduz
Copy link
Collaborator

@juanitorduz I thought it would be easy to fix, but to my surprise, these changes are generating more errors than before. Would you like to take over this issue?

hey, sure! What about if you incrementally push the easiest suggestions (that still work) until you face a problem and then we take it from there? 🙏

@Qazalbash
Copy link
Contributor Author

@juanitorduz I think I have fixed all! See 1e2e670

@Qazalbash Qazalbash requested a review from fehiepsi September 21, 2025 22:32
@juanitorduz
Copy link
Collaborator

Amazing @Qazalbash ! Thank you!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes look great! I just have minor comments. Thank you so much!

@juanitorduz
Copy link
Collaborator

hey @Qazalbash this is finally almost there. Do you need a hand with the last changes?

@Qazalbash
Copy link
Contributor Author

Qazalbash commented Oct 13, 2025

Hi @juanitorduz,

I have explained, the changes @fehiepsi is asking, are violating the Liskov substitution principle. I am waiting for his response.


As per the mypy site,

It’s unsafe to override a method with a more specific argument type, as it violates the Liskov substitution principle. For return types, it’s unsafe to override a method with a more general return type.

NonStaticArray is more specific than NumLike, therefore, we can not use NumLike in the parent class and then override it with NonStaticArray in some child classes.


IMO, one solution could be to set all the contested types to Any in Transform and let the subclasses decide.

@juanitorduz
Copy link
Collaborator

Thank you for the info (I somehow missed it) 🙇

@fehiepsi
Copy link
Member

Somehow I missed your comment, could you ping me on it? We need to allow scalar numbers at places that do not require non-scalar arrays. My comments was to address them. Please let me know if I miss something.

@fehiepsi
Copy link
Member

For parent classes, we need to use NumLike if possible.

@fehiepsi
Copy link
Member

If there is no wip, i can push the changes to your branch (if you prefer)

@Qazalbash
Copy link
Contributor Author

Qazalbash commented Oct 15, 2025

I have pinged you. I have no solution for this issue, if you have any, please share, I can implement it. You are welcome to update my branch too.

@fehiepsi
Copy link
Member

fehiepsi commented Oct 15, 2025

I still cant see that comment. Maybe a github bug. Let me look into the remaining issues then.

@Qazalbash
Copy link
Contributor Author

Please take a look at the reply of the comment.

@fehiepsi
Copy link
Member

Thanks @Qazalbash! Though I still can't see that reply (it is invisible to me), I understand your point now. I just pushed a commit that using Generic type for Transform. That way we can declare array constraints in the subclass. Let me know if you have any opinion about the changes.

@Qazalbash
Copy link
Contributor Author

Thanks @fehiepsi, for looking into it. I see you also have put types in numpyro.distributions.constraints module, I think we should also update the PR title accordingly.

@fehiepsi
Copy link
Member

I think it's fine. I just updated types of arguments in constraints that need to be NumLike for some transforms. It does not resolve typing issues of constraints I believe.

@juanitorduz
Copy link
Collaborator

Where are these failing tests coming from 🤔 ?

@fehiepsi
Copy link
Member

There are a couple of issues: GaussianCopulaBeta is not compatible with jax 0.7, some numerical issues, and many fails due to my change in eq implementation, which I just reverted.

@fehiepsi fehiepsi merged commit 725e009 into pyro-ppl:master Oct 20, 2025
9 checks passed
@Qazalbash Qazalbash deleted the type-hint-transforms branch October 21, 2025 12:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants