use fused multiply-add pointwise ops in chroma #8279
Merged
Micro-optimization for Chroma. Replaces separate multiply/add pointwise ops with `torch.addcmul` or `Tensor.addcmul_` where appropriate. Possibly less readable, but these ops are mathematically equivalent to what existed before. I don't know the full details of how `addcmul` gets lowered, but it should also be more numerically stable (FMA has infinite-precision intermediates). The intent is to mimic some of the performance gains that `torch.compile` gets through pointwise fusion. This isn't the area with the most gains for that (that honor seems to go to torch's awful RMSNorm implementation, which does something like a dozen separate kernel launches in a row), but it does reduce the number of pointwise ops by a fair bit, so it's worth doing. Triton itself also isn't always reliable about lowering separate `mul` and `add` instructions into `fma`, so doing this should likely help `torch.compile` itself out a bit too.
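As a rough sketch of the pattern (not the exact diff in this PR), here is a hypothetical scale-and-shift helper written both ways; `modulate_unfused`, `modulate_fused`, and `residual_add_fused_` are illustrative names, not functions from the repo:

```python
import torch

# Hypothetical example of the rewrite pattern: separate pointwise mul/add
# versus a single fused addcmul call. Function names are illustrative only.

def modulate_unfused(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    # eager mode launches separate kernels for the mul and the add
    return x * (1 + scale) + shift

def modulate_fused(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    # torch.addcmul(input, t1, t2) computes input + t1 * t2 in one op,
    # mathematically equivalent to the version above
    return torch.addcmul(shift, x, scale + 1)

def residual_add_fused_(acc: torch.Tensor, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # in-place flavor: acc += gate * x as a single Tensor.addcmul_ call
    return acc.addcmul_(x, gate)
```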
For future reference for any further efforts here: applying `@torch.compile` to `math.py:attention` and `rmsnorm.py:rms_norm` (which I have confirmed, in my case, is using the native torch implementation) seems to yield most of the performance gains that compiling either the whole model or individual modules in `chroma/layers.py` yields, so those two could probably use some attention for potential gains across multiple models.
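For example (a sketch only, using a generic reference RMSNorm rather than the repo's exact `rms_norm` implementation or signature):

```python
import torch

@torch.compile
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # torch.compile fuses this chain of pointwise ops into far fewer kernels
    # than eager execution would launch
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight
```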