Unify scaled INT8 matmul #862

@gau-nernst

Description

With the addition of INT8 mixed-precision training, there are now two implementations of scaled INT8 matmul (INT8 matmul + dequant).

I have identified the key differences (an eager-mode sketch of both follows the table):

| | intmm_triton.py | int8_mm.py |
|---|---|---|
| Scale fusion | Only fuses the activation scale | Fuses both the activation scale and the weight scale |
| Scale step | `acc_i32 x scale` | Cast to fp32: `acc_i32.to(f32) x scale.to(f32)` |
| Autotune configs | Different autotune configs | Different autotune configs |
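
For concreteness, here is an eager-mode sketch of the two scale steps above. The function names and the per-row/per-column scale shapes are my assumptions; the real kernels fuse these steps inside Triton rather than calling `torch._int_mm`.

```python
import torch

def scaled_mm_intmm_triton_style(a_i8, b_i8, scale_a):
    # intmm_triton.py style: only the activation scale is fused;
    # the scale multiplies the INT32 accumulator directly
    acc_i32 = torch._int_mm(a_i8, b_i8)   # (M, N) INT32 accumulator
    return acc_i32 * scale_a              # scale_a: (M, 1) per-row activation scale

def scaled_mm_int8_mm_style(a_i8, b_i8, scale_a, scale_b):
    # int8_mm.py style: both scales are fused, and the accumulator and
    # the scales are cast to fp32 before multiplying
    acc_i32 = torch._int_mm(a_i8, b_i8)
    return acc_i32.to(torch.float32) * scale_a.to(torch.float32) * scale_b.to(torch.float32)  # scale_b: (1, N)
```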

Ideally we should keep only one. The tedious part is validating that there is no accuracy or speed regression, regardless of which implementation we adopt.

Here are the places that use intmm_triton.py

-> Basically ensure INT8 dynamic quantization for Llama and SAM benchmarks don't regress

Here are the places that use int8_mm.py

-> Ensure INT8 mixed-precision training doesn't regress

Another question: is it OK to change the int_scaled_matmul() signature to accept scales for both A and B, instead of only for A?
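
A hedged sketch of what the changed signature could look like, written as an eager-mode reference rather than the actual Triton kernel. The parameter names are modelled on the existing function but may not match exactly, and the `scales2` argument with a `None` default (for backward compatibility) is my assumption.

```python
import torch
from typing import Optional

def int_scaled_matmul(mat1: torch.Tensor, mat2: torch.Tensor,
                      scales1: torch.Tensor,
                      scales2: Optional[torch.Tensor] = None) -> torch.Tensor:
    # INT8 x INT8 matmul with INT32 accumulation, then dequant with the A scale
    out = torch._int_mm(mat1, mat2).to(torch.float32) * scales1.to(torch.float32)
    if scales2 is not None:
        # Proposed extension: also apply the B (weight) scale, as int8_mm.py does
        out = out * scales2.to(torch.float32)
    return out
```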
