diff --git a/benchmarks/bench_galore_fused_kernels.py b/benchmarks/bench_galore_fused_kernels.py deleted file mode 100644 index 3bfa9056bd..0000000000 --- a/benchmarks/bench_galore_fused_kernels.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -import argparse -import os - -import torch -from fused_benchmark_utils import get_benchmark # , make_data - - -def run(args): - dtype = getattr(torch, args.dtype) - allow_tf32 = args.allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = allow_tf32 - M, N = args.M, args.N - rank = args.rank - - # exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) - - benchmark = get_benchmark(M, N, dtype, allow_tf32=allow_tf32) - save_path = ( - f"benchmark_{M}x{N}_{rank}_{args.dtype}_{'tf32' if allow_tf32 else 'no-tf32'}" - ) - if not os.path.exists(save_path): - os.makedirs(save_path) - print( - f"Running benchmark for {M}x{N}, dtype {args.dtype}, allow_tf32 {allow_tf32}", - flush=True, - ) - benchmark.run(show_plots=False, print_data=True, save_path=save_path) - print(f"Finished benchmark, results saved to {save_path}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "--kernel", - choices=["hybrid", "fused", "compiled"], - default="hybrid", - type=str, - help="Kernel to test", - ) - - parser.add_argument( - "--allow_tf32", action="store_true", help="Allow tf32 for matmuls" - ) - parser.add_argument("--M", type=int, default=4096, help="Grad (param) shape M") - parser.add_argument("--N", type=int, default=4096, help="Grad (param) shape N") - parser.add_argument( - "--rank", type=int, default=128, help="Rank of GaLore projection" - ) - parser.add_argument( - "--dtype", - type=str, - choices=["float32", "float16", "bfloat16"], - default="float32", - help="Data type of grad (param) tensors", - ) - - args = parser.parse_args() - run(args) diff --git a/benchmarks/fused_benchmark_utils.py b/benchmarks/fused_benchmark_utils.py deleted file mode 100644 index c1ae0bfac2..0000000000 --- a/benchmarks/fused_benchmark_utils.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import torch -import triton -from triton.testing import do_bench - -from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher -from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher -from torchao.prototype.galore.kernels.matmul import triton_mm_launcher -from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector - -torch.manual_seed(0) - -BETA1 = 0.9 -BETA2 = 0.999 -EPS = 1e-8 -STEP_SIZE = 1e-4 - - -def make_data(M, N, rank, dtype): - grad = torch.randn(M, N, device="cuda", dtype=dtype) - params = torch.randn(M, N, device="cuda", dtype=dtype) - - galore_proj = GaLoreProjector(rank=rank) - galore_proj.update_orthogonal_matrix(grad) - - if M >= N: - exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) - else: - exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) - exp_avg2 = exp_avg**2 - - return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params - - -def make_copy(*args): - return [t.detach().clone() for t in args] - - -def _ref_op( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - beta1=BETA1, - beta2=BETA2, - eps=EPS, - step_size=STEP_SIZE, - **kwargs, -): - # Step 1: Down proj grad - M, N = grad.shape - if M >= N: - a, b = grad, proj_matrix.t() - else: - a, b = proj_matrix.t(), grad - low_rank_grad = a @ b - - # Step 2: update adam state - exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1)) - exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2) - denom = exp_avg2.sqrt().add_(eps) - low_rank_norm_grad = exp_avg / denom - - # Step 3: project normalized low rank grad to full rank - if M >= N: - a, b = low_rank_norm_grad, proj_matrix - else: - a, b = proj_matrix, low_rank_norm_grad - full_grad_norm = a @ b - - # Finally, update params with updated grad - params.add_(full_grad_norm, alpha=-step_size) - - return exp_avg, exp_avg2, params - - -def _tt_hybrid( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - store=True, - step_size=STEP_SIZE, - fp8_fast_accum=False, - allow_tf32=False, -): - M, N = grad.shape - if M >= N: - a, b = grad, proj_matrix.t() - else: - a, b = proj_matrix.t(), grad - low_rank_grad = a @ b - - exp_avg, exp_avg2, norm_grad = triton_adam_launcher( - exp_avg, exp_avg2, low_rank_grad, store=store - ) - - if M >= N: - a, b = low_rank_grad, proj_matrix - else: - a, b = proj_matrix, low_rank_grad - params = triton_mm_launcher( - a, - b, - epilogue_alpha=-step_size, - epilogue_source=params, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - ) - return exp_avg, exp_avg2, params - - -def _tt_fused( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - store=True, - step_size=STEP_SIZE, - fp8_fast_accum=False, - allow_tf32=False, -): - M, N = grad.shape - - if M >= N: - a, b = grad, proj_matrix.t() - else: - a, b = proj_matrix.t(), grad - exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher( - a, - b, - exp_avg=exp_avg, - exp_avg2=exp_avg2, - store=store, - fp8_fast_accum=fp8_fast_accum, - allow_tf32=allow_tf32, - ) - - if M >= N: - a, b = low_rank_grad, proj_matrix - else: - a, b = proj_matrix, low_rank_grad - params = triton_mm_launcher( - a, - b, - epilogue_alpha=-step_size, - epilogue_source=params, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - ) - return exp_avg, exp_avg2, params - - # logging.basicConfig(level=logging.INFO) - - -def get_kernel(kernel): - if kernel == "ref": - op = _ref_op - elif kernel == "ref": - op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") - elif kernel == 
"hybrid": - op = _tt_hybrid - elif kernel == "fused": - op = _tt_fused - else: - raise ValueError(f"Unknown kernel {kernel}") - - return lambda *args, **kwargs: op(*args, **kwargs) - - -def get_benchmark( - M, N, dtype, allow_tf32, fp8_fast_accum=False, quantiles=[0.5, 0.2, 0.8] -): - config = triton.testing.Benchmark( - x_names=["rank"], # Argument names to use as an x-axis for the plot - x_vals=[ - 32, - 64, - 128, - 256, - 512, - ], # Different possible values for `x_name` - line_arg="kernel", # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=["torch", "hybrid", "fused", "compiled"], - # Label name for the lines - line_names=["torch", "hybrid", "fused", "compiled"], - # Line styles - styles=[("black", "-"), ("blue", "-"), ("red", "-"), ("green", "-")], - ylabel="ms", # Label name for the y-axis - plot_name=f"Adam Kernel Comparison Grad shape: {M}x{N}, dtype: {dtype}, allow_tf32: {allow_tf32}\nMedian times (ms)", # Name for the plot, used also as a file name for saving the plot. - args={}, - ) - - def benchmark(rank, kernel): - torch.backends.cuda.matmul.allow_tf32 = allow_tf32 - - exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) - - if kernel == "torch": - ms, min_ms, max_ms = do_bench( - lambda: _ref_op( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - ), - quantiles=quantiles, - ) - if kernel == "hybrid": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: _tt_hybrid( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - store=True, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - ), - quantiles=quantiles, - ) - if kernel == "fused": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: _tt_fused( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - store=True, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - ), - quantiles=quantiles, - ) - if kernel == "compiled": - compiled_op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: compiled_op( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - ), - quantiles=quantiles, - ) - - return ms, max_ms, min_ms - - return triton.testing.perf_report(config)(benchmark) diff --git a/test/galore/README.md b/test/galore/README.md deleted file mode 100644 index fc479267d8..0000000000 --- a/test/galore/README.md +++ /dev/null @@ -1,170 +0,0 @@ -### GaLore Memory Profiler - -Tests memory usage of `GaLore` optimizers. - -Uses `torch.profiler` under the hood with additional options for `nsys`, [`torch.cuda.memory`](https://pytorch.org/docs/stable/torch_cuda_memory.html) analyses. - -Runs an untrained Llama model with configs for various model sizes (see `configs`) from the original GaLore [repo](https://github.com/jiaweizzhao/GaLore/tree/master/configs) on a sample batch of data for a configurable set of iterations. - -The results of the profiler are saved and can be analyzed using the provided notebook. 
-
-#### Examples
-
-Run memory profiler with `torch.optim.AdamW`
-
-```
-python galore_mem_prof.py -t --optimizer=adamw
-```
-
-Run profiler with `GaLoreAdamW` reference implementation with rank 128
-
-```
-python galore_mem_prof.py -t --optimizer=galore_adamw --rank=128
-```
-
-More options:
-
-```
-python profile_memory_usage.py --help
-
-usage: profile_memory_usage.py [-h] [-t] [-m] [-ns] [--optimizer {adamw,galore_adamw}] [--rank RANK] [--update_proj_gap UPDATE_PROJ_GAP]
- [--galore_scale GALORE_SCALE] [--wait_steps WAIT_STEPS] [--warmup_steps WARMUP_STEPS] [--profiler_steps PROFILER_STEPS]
- [--max_steps MAX_STEPS] [--model_config MODEL_CONFIG] [--data_path DATA_PATH] [--output_dir OUTPUT_DIR] [-lr LEARNING_RATE]
- [--weight_decay WEIGHT_DECAY] [--seed SEED]
-
-options:
- -h, --help show this help message and exit
- -t, --torch_profiler Enable torch profiler (default: False)
- -m, --torch_memory_snapshot
- Enable torch memory snapshot (default: False)
- -ns, --nsys_profiler Enable nsys profiling context managerSurrounds training loop with cudaProfilerApi.{Start,Stop} (default: False)
- --optimizer {adamw,galore_adamw}
- Which optimizer to use (default: adamw)
- --rank RANK
- --update_proj_gap UPDATE_PROJ_GAP
- --galore_scale GALORE_SCALE
- --wait_steps WAIT_STEPS
- Number of steps to run before starting torch profiler (default: 0)
- --warmup_steps WARMUP_STEPS
- Number of warmup steps for torch profiler (default: 0)
- --profiler_steps PROFILER_STEPS
- Number of active steps for torch profiler (default: 5)
- --max_steps MAX_STEPS
- Max number of train steps to run.Total train steps will be min of `max_steps` and the sum of torch profiler steps (`wait_steps` +
- `warmup_steps` + `profiler_steps`). (default: 100)
- --model_config MODEL_CONFIG
- Path to Llama config file see `https://github.com/jiaweizzhao/GaLore/tree/master/configs` (default: ./configs/llama_100m.json)
- --data_path DATA_PATH
- Path to sample batch (default: ./data/sample_batch.pt)
- --output_dir OUTPUT_DIR
- Directory for profiler outputs (default: profiler_out)
- -lr LEARNING_RATE, --learning_rate LEARNING_RATE
- Learning rate (default: 0.001)
- --weight_decay WEIGHT_DECAY
- Weight decay for AdamW (default: 0.01)
- --seed SEED Random seed for torch (default: 0)
-```
-
-#### Analysis
-
-After running `profile_memory_usage.py`, the output directory (defaults to `profiler_out`) will have three types of files:
-
-- `*.{json,html}` - these are the memory trace exports of `torch.profiler`
-  - the `html` contains the memory timeline plot
-  - the `json` file contains the raw data for this plot, which can be analyzed to extract summary stats.
-  - `galore_memory_analysis.py` along with `galore_memory_analysis_utils.py` demonstrate such analysis.
-- `*.json.gz` - these are the complete `torch.profiler` traces which can be viewed using `perfetto`.
-
-#### Preliminary Observations
-
-- Memory Usage over Time
-
-  - We can see a long delay following the first backward step for `GaLoreAdamW`, due to the calculation of the projection matrix (calls `torch.linalg.svd` on the `grad`).
- - To visualize, paste the following into a jupyter notebook (replacing the filenames with the those after running the profiler script): - - ```python - adamW_html_trace = "./profiler_out/adamw_04-09-23.html" - adamW8bit_html_trace = "./profiler_out/adamw8bit_04-11-01.html" - galore_adamw_128_html_trace = "./profiler_out/galore_adamw-128-1.0-50_04-09-23.html" - galore_adamw8bit_128_html_trace = "./profiler_out/galore_adamw8bit-128-1.0-50_04-11-01.html" - - plot_memory_timeline(adamW_html_trace) - plot_memory_timeline(adamW8bit_html_trace) - plot_memory_timeline(galore_adamw_128_html_trace) - plot_memory_timeline(galore_adamw8bit_128_html_trace) - ``` - -- Memory Usage Stats - - - Summary stats for memory usage by type as well as total across all types can be viewed by running the following in jupyter notebook, again replacing the respective filepaths: - - ```python - adamW_trace = "./profiler_out/adamw_04-11-21-memory-timeline.json" - adamW8bit_trace = "./profiler_out/adamw8bit_04-11-21-memory-timeline.json" - galore_adamW_trace_128 = "./profiler_out/galore_adamw-128-1.0-50_04-11-21-memory-timeline.json" - galore_adamW8bit_trace_128 = "./profiler_out/galore_adamw8bit-128-1.0-50_04-11-21-memory-timeline.json" - - adamW_df = create_mem_df(adamW_trace, units="MB") - adamW8bit_df = create_mem_df(adamW8bit_trace, units="MB") - galore_adamW_df_128 = create_mem_df(galore_adamW_trace_128, units="MB") - galore_adamW8bit_df_128 = create_mem_df(galore_adamW8bit_trace_128, units="MB") - - show_memory_stats(adamW_df) - show_memory_stats(adamW8bit_df) - show_memory_stats(galore_adamW_df_128) - show_memory_stats(galore_adamW8bit_df_128) - ``` - - The following are results from sample runs of `Llama1B` model config with the following optimizers (all units in MB): - -- torch.optim.AdamW - - | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | - | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | - | mean | 5,108.2 | 8,330.3 | 0.0 | 0.6 | 2,249.5 | 2,113.8 | 19.0 | 197.3 | 18,018.8 | - | min | 5,108.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | - | median | 5,108.2 | 10,216.4 | 0.0 | 0.0 | 2,151.1 | 1,930.1 | 10.0 | 16.3 | 20,306.5 | - | max | 5,108.3 | 10,216.4 | 0.3 | 20.0 | 5,946.4 | 5,108.2 | 312.2 | 5,124.4 | 25,557.3 | - -- GaLoreAdamW reference, rank 128 - - | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | - | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | - | mean | 7,298.0 | 1,348.4 | 0.0 | 0.7 | 1,455.6 | 3,183.6 | 12.2 | 31.3 | 13,330.0 | - | min | 5,108.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | - | median | 7,796.2 | 1,576.7 | 0.0 | 0.0 | 545.4 | 3,898.2 | 0.0 | 26.2 | 14,422.8 | - | max | 8,047.2 | 1,576.7 | 0.3 | 42.7 | 5,960.0 | 5,108.2 | 312.2 | 518.2 | 15,349.2 | - -- bitsandbytes AdamW8bit - - | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total | - | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- | - | mean | 5,108.2 | 2,047.4 | 0.0 | 0.7 | 2,390.0 | 1,925.2 | 20.1 | 20.3 | 11,511.9 | - | min | 5,108.2 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 | - | median | 5,108.2 | 2,560.4 | 0.0 | 0.0 | 2,351.0 | 1,738.1 | 10.0 | 16.3 | 12,621.3 | - | max | 5,108.3 | 
2,560.4 | 0.3 | 20.0 | 5,946.4 | 5,108.2 | 312.2 | 46.9 | 13,631.3 |
-
-- GaLore AdamW8bit
-
- | | Parameter | Optimizer_State | Input | Temporary | Activation | Gradient | Autograd_Detail | Unknown | Total |
- | ------ | --------- | --------------- | ----- | --------- | ---------- | -------- | --------------- | ------- | -------- |
- | mean | 4,971.0 | 334.7 | 0.1 | 0.8 | 1,644.0 | 2,130.9 | 13.8 | 2,360.3 | 11,455.6 |
- | min | 500.4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 5,108.2 |
- | median | 5,108.2 | 395.6 | 0.0 | 0.0 | 1,076.4 | 2,106.1 | 0.0 | 2,704.3 | 11,673.8 |
- | max | 5,153.5 | 395.6 | 85.4 | 42.7 | 5,947.8 | 5,109.2 | 312.2 | 7,685.4 | 14,155.9 |
-
-- The `optimizer state` is indeed smaller for the `GaLoreAdamW` optimizer.
-- Interestingly, the `Parameter` size balloons in the `GaLore` optimizer, likely due to extra data copies. Admittedly, the implementation is only a reference (per original repo) and leaves much room for optimization.
-- The memory usage is in terms of memory allocated, which we can confirm by printing the max CUDA memory allocated vs. reserved (which the profiler script prints automatically).
-- The `Total` column shows the stats of total allocation summed across all categories at each sampled timepoint; it should not be interpreted as the row-wise sum of the other columns.
-
-**NOTE**: The `json` output of the torch profiler memory trace is unlabeled. However, we can infer -- and confirm -- the labels by comparing the plots of the parsed dataframe with that of the direct `html` export of the profiler.
-
-- For example, after creating the dataframes per above, the following will plot the raw data, which should roughly reproduce the direct `html` export from `torch.profiler`, albeit with a different timescale:
-
-```python
-_ = adamW_df.plot(kind="area", stacked=True, ylabel="Memory (MB)" )
-_ = adamW8bit_df.plot(kind="area", stacked=True, ylabel="Memory (MB)" )
-_ = galore_adamW_df_128.plot(kind="area", stacked=True, ylabel="Memory (MB)" )
-_ = galore_adamW8bit_df_128.plot(kind="area", stacked=True, ylabel="Memory (MB)" )
-```
diff --git a/test/galore/profile_memory_usage.py b/test/galore/profile_memory_usage.py
deleted file mode 100644
index 33fd746c39..0000000000
--- a/test/galore/profile_memory_usage.py
+++ /dev/null
@@ -1,297 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-import argparse -import contextlib -import logging -import os - -import model_configs -import profiling_utils -import torch -import torch.nn as nn -import torch.utils.data -from bitsandbytes.optim import AdamW8bit -from torch.profiler import record_function -from transformers import LlamaConfig, LlamaForCausalLM - -from torchao.prototype.galore.optim.galore_torch import AdamW as GaLoreAdamW -from torchao.prototype.galore.optim.galore_torch import AdamW8bit as GaLoreAdamW8bit - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def setup_galore(model, lr, weight_decay, rank, galore_scale, update_proj_gap): - galore_params = [] - target_modules_list = ["attn", "mlp"] - for module_name, module in model.named_modules(): - if not isinstance(module, nn.Linear): - continue - - if not any(target_key in module_name for target_key in target_modules_list): - continue - - logger.debug("Enabling GaLore for weights in module: ", module_name) - galore_params.append(module.weight) - id_galore_params = [id(p) for p in galore_params] - # make parameters without "rank" to another group - regular_params = [p for p in model.parameters() if id(p) not in id_galore_params] - # then call galore_adamw - - total_galore_params = sum(p.numel() for p in galore_params) - total_regular_params = sum(p.numel() for p in regular_params) - total_params = sum(p.numel() for p in model.parameters()) - assert total_galore_params + total_regular_params == total_params - - print( - f"Total params: {total_params} = GaLore params: {total_galore_params} + Regular params: {total_regular_params}" - ) - param_groups = [ - {"params": regular_params}, - { - "params": galore_params, - "rank": rank, - "update_proj_gap": update_proj_gap, - "scale": galore_scale, - "proj_type": "std", - }, - ] - if "adamw" in args.optimizer: - if "8bit" in args.optimizer: - optimizer = GaLoreAdamW8bit(param_groups, lr=lr, weight_decay=weight_decay) - else: - optimizer = GaLoreAdamW(param_groups, lr=lr, weight_decay=weight_decay) - else: - raise ValueError(f"Unknown optimizer: {args.optimizer}") - return optimizer - - -def train_step(model, batch, labels, optimizer, profiler=None): - with record_function("MODEL_FORWARD"): - loss = model(**batch, labels=labels).loss - - with record_function("MODEL_BACKWARD"): - loss.backward() - - with record_function("OPTIMIZER_STEP"): - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - if profiler: - profiler.step() - - -def run(args, file_prefix): - torch.manual_seed(args.seed) - - # Initialize model from config dict - model_config = LlamaConfig() - try: - model_config_dict = getattr(model_configs, args.model_config.upper()) - except: - raise ValueError(f"Model config {args.model_config} not found") - model_config.update(model_config_dict) - model = LlamaForCausalLM(model_config).to("cuda") - - # Load sample batch - input_ids = torch.randint( - 0, - model_config.vocab_size, - size=(args.batch_size, args.max_seq_len), - dtype=torch.int64, - device="cuda", - ) - attention_mask = torch.ones_like(input_ids) - batch = dict(input_ids=input_ids, attention_mask=attention_mask) - labels = batch["input_ids"].clone() - - n_total_params = sum(p.numel() for p in model.parameters()) - trainable_params = [p for p in model.parameters() if p.requires_grad] - print( - f"Trainable params: {sum(p.numel() for p in trainable_params)} / {n_total_params}" - ) - - if args.optimizer.lower() == "adamw": - optimizer = torch.optim.AdamW( - trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay 
- ) - - elif "galore" in args.optimizer.lower(): - optimizer = setup_galore( - model, - args.learning_rate, - args.weight_decay, - rank=args.rank, - galore_scale=args.galore_scale, - update_proj_gap=args.update_proj_gap, - ) - elif args.optimizer.lower() == "adamw8bit": - optimizer = AdamW8bit( - trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay - ) - else: - raise "Unsupported optimizer" - - if args.torch_profiler: - prof_ctx = profiling_utils.get_torch_profiler( - name=file_prefix, - output_dir=args.output_dir, - wait_steps=args.wait_steps, - warmup_steps=args.warmup_steps, - active_steps=args.profiler_steps, - ) - elif args.nsys_profiler: - prof_ctx = profiling_utils.nsys_profiler() - else: - prof_ctx = contextlib.nullcontext() - - total_steps = min( - args.wait_steps + args.warmup_steps + args.profiler_steps, args.max_steps - ) - print( - f"Profiling {args.model_config} with {args.optimizer.upper()} for {total_steps} steps (wait_steps={args.wait_steps}, warmup_steps={args.warmup_steps}, profiler_steps={args.profiler_steps})" - ) - with prof_ctx as prof: - logger.debug(f"Profiler: {prof}") - for _ in range(total_steps): - with record_function("TRAIN_STEP"): - train_step( - model, - batch, - labels, - optimizer, - profiler=prof if args.torch_profiler else None, - ) - if args.torch_profiler: - print(f"Finished profiling, outputs saved to {args.output_dir}/{file_prefix}*") - else: - print("Finished profiling") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - parser.add_argument( - "-t", "--torch_profiler", action="store_true", help="Enable torch profiler" - ) - parser.add_argument( - "-m", - "--torch_memory_snapshot", - action="store_true", - help="Enable torch memory snapshot", - ) - - parser.add_argument( - "-ns", - "--nsys_profiler", - action="store_true", - help="Enable nsys profiling context manager" - "Surrounds training loop with cudaProfilerApi.{Start,Stop}", - ) - parser.add_argument( - "--optimizer", - default="adamw", - type=str, - choices=["adamw", "galore_adamw", "adamw8bit", "galore_adamw8bit"], - help="Which optimizer to use", - ) - parser.add_argument("--rank", type=int, default=128) - parser.add_argument("--update_proj_gap", type=int, default=50) - parser.add_argument("--galore_scale", type=float, default=1.0) - # parser.add_argument("--proj_type", type=str, default="std") - parser.add_argument( - "--wait_steps", - type=int, - default=0, - help="Number of steps to run before starting torch profiler", - ) - parser.add_argument( - "--warmup_steps", - type=int, - default=0, - help="Number of warmup steps for torch profiler", - ) - - parser.add_argument( - "--profiler_steps", - type=int, - default=5, - help="Number of active steps for torch profiler", - ) - parser.add_argument( - "--max_steps", - type=int, - default=100, - help="Max number of train steps to run." 
- "Total train steps will be min of `max_steps` and the sum of torch profiler steps (`wait_steps` + `warmup_steps` + `profiler_steps`).", - ) - parser.add_argument( - "--model_config", - default="llama100M", - type=str, - choices=["llama100M", "llama1B"], - help="Model configuration (see model_configs.py)", - ) - parser.add_argument( - "--batch_size", default=5, type=int, help="Batch size to use for train step" - ) - parser.add_argument( - "--max_seq_len", - default=256, - type=int, - help="Sequence length to use for train step, should be less than that in the specific model config", - ) - parser.add_argument( - "--output_dir", - default="profiler_out", - type=str, - help="Directory for profiler outputs", - ) - - parser.add_argument( - "-lr", - "--learning_rate", - default=1e-3, - type=float, - help="Learning rate", - ) - parser.add_argument( - "--weight_decay", - default=1e-2, - type=float, - help="Weight decay for AdamW", - ) - - parser.add_argument("--seed", default=0, type=int, help="Random seed for torch") - args = parser.parse_args() - output_dir = args.output_dir - # output_prefix = args.output_prefix - if not os.path.exists(output_dir): - os.makedirs(output_dir) - if "galore" not in args.optimizer.lower(): - file_prefix = args.optimizer.lower() - else: - file_prefix = "-".join( - [ - args.optimizer.lower(), - str(args.rank), - str(args.galore_scale), - str(args.update_proj_gap), - ] - ) - mem_ctx = ( - profiling_utils.memory_recorder( - file_name=os.path.join(output_dir, f"{file_prefix}-memory-snapshot") - ) - if args.torch_memory_snapshot - else contextlib.nullcontext() - ) - profiling_utils.flush_cuda_mem() - with mem_ctx: - run(args, file_prefix) - - profiling_utils.get_cuda_memory_usage(units="MB", show=True) diff --git a/test/kernel/galore_test_utils.py b/test/kernel/galore_test_utils.py deleted file mode 100644 index 2810941fe1..0000000000 --- a/test/kernel/galore_test_utils.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import torch - -from torchao.prototype.galore.kernels.adam_downproj_fused import fused_adam_mm_launcher -from torchao.prototype.galore.kernels.adam_downproj_fused import ( - set_tuner_top_k as adam_downproj_tuner_topk, -) -from torchao.prototype.galore.kernels.adam_step import triton_adam_launcher -from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk -from torchao.prototype.galore.kernels.matmul import triton_mm_launcher -from torchao.prototype.galore.utils import TestGaLoreProjector as GaLoreProjector - -torch.manual_seed(0) - -adam_downproj_tuner_topk(10) -matmul_tuner_topk(10) - -BETA1 = 0.9 -BETA2 = 0.999 -EPS = 1e-8 -STEP_SIZE = 1e-4 - - -def make_data(M, N, rank, dtype): - grad = torch.randn(M, N, device="cuda", dtype=dtype) - params = torch.randn(M, N, device="cuda", dtype=dtype) - - galore_proj = GaLoreProjector(rank=rank) - galore_proj.update_orthogonal_matrix(grad) - - if M >= N: - exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) - else: - exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) - exp_avg2 = exp_avg**2 - - return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params - - -def make_copy(*args): - return [t.detach().clone() for t in args] - - -def _ref_op( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - beta1=BETA1, - beta2=BETA2, - eps=EPS, - step_size=STEP_SIZE, - **kwargs, -): - # Step 1: Down proj grad - M, N = grad.shape - if M >= N: - a, b = grad, proj_matrix.t() - else: - a, b = proj_matrix.t(), grad - low_rank_grad = a @ b - - # Step 2: update adam state - exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1)) - exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2) - denom = exp_avg2.sqrt().add_(eps) - low_rank_norm_grad = exp_avg / denom - - # Step 3: project normalized low rank grad to full rank - if M >= N: - a, b = low_rank_norm_grad, proj_matrix - else: - a, b = proj_matrix, low_rank_norm_grad - full_grad_norm = a @ b - - # Finally, update params with updated grad - params.add_(full_grad_norm, alpha=-step_size) - - return exp_avg, exp_avg2, params - - -def _tt_hybrid( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - store=True, - step_size=STEP_SIZE, - fp8_fast_accum=False, - allow_tf32=False, -): - M, N = grad.shape - if M >= N: - a, b = grad, proj_matrix.t() - else: - a, b = proj_matrix.t(), grad - low_rank_grad = a @ b - - exp_avg, exp_avg2, norm_grad = triton_adam_launcher( - exp_avg, exp_avg2, low_rank_grad, store=store - ) - - if M >= N: - a, b = low_rank_grad, proj_matrix - else: - a, b = proj_matrix, low_rank_grad - params = triton_mm_launcher( - a, - b, - epilogue_alpha=-step_size, - epilogue_source=params, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - ) - return exp_avg, exp_avg2, params - - -def _tt_fused( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - store=True, - step_size=STEP_SIZE, - fp8_fast_accum=False, - allow_tf32=False, -): - M, N = grad.shape - - if M >= N: - a, b = grad, proj_matrix.t() - else: - a, b = proj_matrix.t(), grad - exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher( - a, - b, - exp_avg=exp_avg, - exp_avg2=exp_avg2, - store=store, - fp8_fast_accum=fp8_fast_accum, - allow_tf32=allow_tf32, - ) - - if M >= N: - a, b = low_rank_grad, proj_matrix - else: - a, b = proj_matrix, low_rank_grad - params = triton_mm_launcher( - a, - b, - epilogue_alpha=-step_size, - epilogue_source=params, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - ) - return exp_avg, exp_avg2, params - - # 
logging.basicConfig(level=logging.INFO) - - -def get_kernel(kernel): - if kernel == "ref": - op = _ref_op - elif kernel == "ref": - op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") - elif kernel == "hybrid": - op = _tt_hybrid - elif kernel == "fused": - op = _tt_fused - else: - raise ValueError(f"Unknown kernel {kernel}") - - return lambda *args, **kwargs: op(*args, **kwargs) diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py deleted file mode 100644 index 3c51b78f1b..0000000000 --- a/test/kernel/test_fused_kernels.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -import itertools - -import pytest - -# Skip entire test if triton is not available, otherwise CI failure -try: - import triton # noqa: F401 -except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) - -import torch -from galore_test_utils import get_kernel, make_copy, make_data - -from torchao.testing.utils import skip_if_rocm - -torch.manual_seed(0) -MAX_DIFF_no_tf32 = 1e-5 -MAX_DIFF_tf32 = 1e-3 - - -def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): - # Copy to use for first run -- needed because of autotuning and inplace ops - ( - exp_avg_autotune_copy, - exp_avg2_autotune_copy, - grad_autotune_copy, - proj_matrix_autotune_copy, - params_autotune_copy, - ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) - - # Copy to use for second run to check accuracy - ( - exp_avg_test_copy, - exp_avg2_test_copy, - grad_test_copy, - proj_matrix_test_copy, - params_test_copy, - ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) - - print( - f"Running with {grad.shape[0]} x {grad.shape[1]} grad (param) shape, GaLore orthogonal matrix {list(proj_matrix.shape)}, dtype {grad.dtype} and allow_tf32 {allow_tf32}\n" - f"Kernel: {kernel}", - flush=True, - ) - - ref_op = get_kernel("ref") - test_op = get_kernel(kernel) - - # Reference run - ref_out = ref_op( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - ) - - # Autotune - _ = test_op( - grad_autotune_copy, - proj_matrix_autotune_copy, - exp_avg_autotune_copy, - exp_avg2_autotune_copy, - params_autotune_copy, - store=False, - allow_tf32=allow_tf32, - ) - - # Accuracy run - test_out = test_op( - grad_test_copy, - proj_matrix_test_copy, - exp_avg_test_copy, - exp_avg2_test_copy, - params_test_copy, - store=True, - allow_tf32=allow_tf32, - ) - print("Accuracy:") - - output_names = [ - "adam state - running grad mean", - "adam state - running grad var", - "params (after update)", - ] - MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 - for name, ref, tt in zip(output_names, ref_out, test_out): - max_diff = (ref - tt).abs().max() - print(f"-> {name}:\n Max err: {max_diff:.6f}") - assert max_diff < MAX_DIFF - - -KERNELS = ["hybrid"] # "fused"] -DTYPES = [torch.float32] # torch.float16 -ROW_DIMS = [4096] -COL_DIMS = [4096] # , 11008] -RANKS = [128] -ALLOW_TF32 = [False] # , True] - -TEST_CONFIGS = list( - itertools.product(KERNELS, DTYPES, ROW_DIMS, COL_DIMS, RANKS, ALLOW_TF32) -) - -# TEST_CONFIGS = TEST_CONFIGS[0:1] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -@pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS) -@skip_if_rocm("ROCm enablement in progress") -def test_galore_fused_kernels(kernel, dtype, 
M, N, rank, allow_tf32): - torch.backends.cuda.matmul.allow_tf32 = allow_tf32 - - exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) - run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32) diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py deleted file mode 100644 index f0e135667e..0000000000 --- a/test/kernel/test_galore_downproj.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -import pytest - -# Skip entire test if triton is not available, otherwise CI failure -try: - import triton # noqa: F401 -except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) - -import torch -from galore_test_utils import make_data - -from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk -from torchao.prototype.galore.kernels.matmul import triton_mm_launcher -from torchao.testing.utils import skip_if_rocm - -torch.manual_seed(0) - -matmul_tuner_topk(10) -MAX_DIFF_no_tf32 = 1e-4 -MAX_DIFF_tf32 = 1e-2 - - -TEST_CONFIGS = [ - # (4096, 4096, 128, True, False, torch.float32), - (4096, 4096, 128, False, False, torch.float32), - # (4096, 11008, 128, True, False, torch.float32), - (4096, 11008, 128, False, False, torch.float32), -] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -@pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS) -@skip_if_rocm("ROCm enablement in progress") -def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype): - torch.backends.cuda.matmul.allow_tf32 = allow_tf32 - MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 - exp_avg, exp_avg2, grad, galore_proj, params = make_data(M, N, rank, dtype) - - if M >= N: - a, b = grad, galore_proj.t() - else: - a, b = galore_proj.t(), grad - low_rank_ref = lambda: a @ b - low_rank_tt = lambda: triton_mm_launcher( - a, b, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum - ) - diff = torch.max(torch.abs(low_rank_ref() - low_rank_tt())) - if not diff < MAX_DIFF: - print("diff: ", torch.max(torch.abs(low_rank_ref() - low_rank_tt()))) - assert diff < MAX_DIFF diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py deleted file mode 100644 index cb2902d00f..0000000000 --- a/test/quantization/test_galore_quant.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import itertools - -import pytest - -from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 - -# Skip entire test if triton is not available, otherwise CI failure -try: # noqa: F401 - import triton # noqa: F401 -except ImportError: # noqa: F401 - pytest.skip("triton is not installed", allow_module_level=True) # noqa: F401 -import torch - -# Skip entire test if CUDA is not available or ROCM is enabled -if not torch.cuda.is_available() or torch.version.hip is not None: - pytest.skip( - "CUDA is not available/ ROCM support is under development", - allow_module_level=True, - ) - -from bitsandbytes.functional import ( - create_dynamic_map, - dequantize_blockwise, - quantize_blockwise, -) - -from torchao.prototype.galore.kernels import ( - triton_dequant_blockwise, - triton_quantize_blockwise, -) -from torchao.testing.utils import skip_if_rocm - -SEED = 0 -torch.manual_seed(SEED) - -DIM1 = [64, 1024, 4096] -DIM2 = [1024, 2048, 4096] -SIGNS = [True, False] -DTYPES = [torch.float32] # , torch.float16] -BLOCKSIZE = [2048] - -TEST_CONFIGS = list(itertools.product(DIM1, DIM2, DTYPES, SIGNS, BLOCKSIZE)) - - -@pytest.mark.skip("skipping for now, see comments below") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.parametrize( - "dim1,dim2,dtype,signed,blocksize", - TEST_CONFIGS, -) -def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): - g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 - - qmap = create_dynamic_map(signed).to(g.device) - - ref_bnb, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize) - bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape) - - tt_q, tt_norm, tt_absmax = triton_quantize_blockwise( - g, qmap, group_size=blocksize, return_normalized=True - ) - tt_check = torch.allclose(ref_bnb, tt_q) - - # see notes.md under `prototype.galore.kernels` for an explanation of the following conditions - if not tt_check: - print( - f"Failed quantization check for {dim1} x {dim2}, {dtype}, signed {signed}" - ) - print(f"Absmax: {(qstate.absmax - tt_absmax).abs().max()}") - print(f"Norm diff: {(bnb_norm - tt_norm).abs().max()}") - - idx_diff = (ref_bnb != tt_q).to("cuda") - print(f"Num code idx diffs: {idx_diff.sum()}") - max_idx_diff = (ref_bnb - tt_q).abs().max() - print(f"Max code idx diff: {max_idx_diff}") - - # This below checks that the value being quantized falls half-way between two code buckets - # where bitsandbytes assigns to one and the triton implementation assigns to the other - # Since either bucket is technically valid, we only check that the distance between the value and the - # adjacent buckets are the same. I.e., we don't require that the triton implementation exactly matches - # bitsandbytes. 
- - bnb_code = qmap[ref_bnb[idx_diff].tolist()] - tt_code = qmap[tt_q[idx_diff].tolist()] - bnb_dist = torch.abs(bnb_code - bnb_norm[idx_diff]) - torch_dist = torch.abs(tt_code - bnb_norm[idx_diff]) - - dist_sum = torch.sum(bnb_dist - torch_dist) - print(f"Distance sum: {torch.sum(bnb_dist - torch_dist)}") - assert tt_check or (not tt_check and dist_sum < 1e-4) - - -@pytest.mark.parametrize( - "dim1,dim2,dtype,signed,blocksize", - TEST_CONFIGS, -) -@skip_if_rocm("ROCm enablement in progress") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") -@pytest.mark.skipif( - TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI" -) # TODO: fix this -def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): - g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 - - qmap = create_dynamic_map(signed).to(g.device) - - q, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize) - - dq_ref = dequantize_blockwise(q, qstate) - dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize) - assert torch.allclose(dq, dq_ref) diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md index 70f9d87537..257ba4ffb8 100644 --- a/torchao/prototype/README.md +++ b/torchao/prototype/README.md @@ -6,9 +6,6 @@ #### Code structure -- `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507) - - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm - - `galore/docs` - implementation notes and discussion of issues faced in kernel design. - [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) - ~~`low_bit_optim`~~ - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers). **Promoted to `torchao.optim`.** - [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406) diff --git a/torchao/prototype/galore/README.md b/torchao/prototype/galore/README.md deleted file mode 100644 index 2a7ae1f7d9..0000000000 --- a/torchao/prototype/galore/README.md +++ /dev/null @@ -1,11 +0,0 @@ -## Fused GaLore - -### Experimental kernels for fusing various parts of the GaLore algorithm - -#### AdamW - -See `docs/galore_adam.md` for implementation notes. - -#### AdamW8bit - -See `docs/galore_adam8bit.md` for implementation notes. diff --git a/torchao/prototype/galore/docs/README.md b/torchao/prototype/galore/docs/README.md deleted file mode 100644 index 74b077c4a9..0000000000 --- a/torchao/prototype/galore/docs/README.md +++ /dev/null @@ -1,198 +0,0 @@ -## Fused GaLore Adam (WIP) - -### Various fused implementations of `Adam` update step per [Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507) - -This is an initial attempt at optimizing the update step of the `GaLore Adam` optimizer. - -#### Overview - -The `GaLore` `Adam` optimizer introduces additional ops to the traditional `adam` update step. - -Specifically: - -1. `grad` is projected to low rank --> additional matmul -2. `adam` states are updated with `grad` elementwise (same as `Adam` except in low-rank) -3. normalized `grad` is projected to full rank --> additional matmul -4. 
`params` are updated with the normalized full rank grad - -#### Implementation - -Various fusions were attempted across 2 kernel implementations: - -- `Fused` - - Steps 1 & 2 are fused: the `adam` state updates are loaded and updated (inplace) during the first `matmul` - - Steps 3 & 4 are fused: the param update is folded as an epilogue into the second `matmul` -- `Hybrid` - - Step 1 is performed using standard `torch matmul` (i.e., `cuBlas`) - - Step 2 is fused as an elementwise kernel - - Steps 3 & 4 per `Fused` - -#### Performance - -Below are benchmarks for various kernels: - -- `torch` - reference `torch` implementation where each of the steps are implemented verbatim per above -- `hybrid` - see above -- `fused` - see above -- `compiled` - `torch` reference implementation compiled using `torch.compile` with `fullgraph=True` and `mode="max-autotune"`. - -Configs for each benchmark are the `grad (param)` shape, `dtype` of `grad` and `adam` states, and `allow_tf32`, whether `torch` and `triton` matmuls are allowed to use `TF32` tensor cores (see `Discussion`). - -`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `False` - -``` -Median times (ms): - rank torch hybrid fused compiled -0 32.0 0.560128 0.347136 0.505856 0.534528 -1 64.0 0.627712 0.404480 0.600960 0.615424 -2 128.0 0.825232 0.583168 0.985072 0.833536 -3 256.0 1.378304 1.126400 1.489920 1.375232 -4 512.0 2.286080 2.101760 2.969600 2.302976 -``` - -`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `True` - -``` -Median times (ms): - rank torch hybrid fused compiled -0 32.0 0.540672 0.321536 0.316416 0.508928 -1 64.0 0.612240 0.337728 0.345024 0.538624 -2 128.0 0.640000 0.395264 0.393216 0.693248 -3 256.0 0.777216 0.489472 0.548784 1.102848 -4 512.0 1.216512 0.864256 0.960512 1.968128 -``` - -`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `False` - -``` -Median times (ms): - rank torch hybrid fused compiled -0 32.0 1.538672 0.915456 0.835584 1.364032 -1 64.0 1.546240 0.940032 1.022976 1.486848 -2 128.0 2.116608 1.498112 1.613312 2.098176 -3 256.0 3.423744 2.719744 2.881536 3.227136 -4 512.0 5.499904 5.036544 5.450752 5.508096 -``` - -`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `True` - -``` -Median times (ms): - rank torch hybrid fused compiled -0 32.0 1.413120 0.871424 0.817152 1.353184 -1 64.0 1.489920 0.916480 0.854016 1.389568 -2 128.0 1.679360 0.996352 1.005568 1.563648 -3 256.0 2.152448 1.415168 1.470464 2.185216 -4 512.0 3.210240 2.460672 2.580480 3.477504 -``` - -##### Accuracy - -Comparison to reference `torch` implementation: - -``` -Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32, and allow_tf32 True -Kernel: hybrid -Accuracy: --> adam state - running grad mean: - Max err: 0.000000 Relative err: 0.000001 --> adam state - running grad var: - Max err: 0.000002 Relative err: 0.000002 --> params (after update): - Max err: 0.000000 Relative err: 0.000001 -``` - -``` -Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False -Kernel: hybrid -Accuracy: --> adam state - running grad mean: - Max err: 0.000000 Relative err: 0.000000 --> adam state - running grad var: - Max err: 0.000002 Relative err: 0.000002 --> params (after update): - Max err: 0.000000 Relative err: 0.000000 -``` - -``` -Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 True -Kernel: fused 
-Accuracy:
--> adam state - running grad mean:
- Max err: 0.000845 Relative err: 0.001152
--> adam state - running grad var:
- Max err: 0.000162 Relative err: 0.000161
--> params (after update):
- Max err: 0.000000 Relative err: 0.000001
-```
-
-```
-Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
-Kernel: fused
-Accuracy:
--> adam state - running grad mean:
-Max err: 0.000003 Relative err: 0.000004
--> adam state - running grad var:
-Max err: 0.000002 Relative err: 0.000002
--> params (after update):
-Max err: 0.000000 Relative err: 0.000000
-```
-
-#### Discussion
-
-##### Down Projection GEMM Shape
-
-The motivation for the `hybrid` approach is the unconventional matrix shapes of the down projection (Step 1):
-
-- The projection is always done such that the larger dimension of the `grad` matrix is maintained while the other is projected to low rank, per the `GaLore` algorithm
-  - E.g., if `M >= N`, the GEMM is of shape (`M x N`) x (`N x rank`) = (`M x rank`); otherwise, it is (`rank x M`) x (`M x N`) = (`rank x N`)
-- Since `{M, N} >> rank` by definition, this results in a large reduction dimension relative to one of the output dimensions (the output matrix is either fat or skinny)
-- This does not fit cleanly into the `split-k / parallel reduction` `GEMM` paradigm, which is more tailored for shapes where both output dims are smaller than the reduction dimension.
-- Consequently, I had trouble finding an optimal kernel config using the `triton` `autotuner` for the down projection step, despite tuning across many compute and io-bound configs (see `fused.triton_utils.kernels.matmul.py`).
-- Benchmarking the `triton`-tuned `matmul` against the default `torch.matmul` for these shapes showed worse performance for `torch.float32`
-
-##### Effect of `TF32` tensor cores
-
-`allow_tf32`: this has a significant impact on the relative performance of `triton` vs `torch` matmuls:
-
-- Quick benchmarks of the downprojection `matmul` show that:
-  - with `allow_tf32=True` for both, triton exhibits `~1.30x` performance improvement over `torch`.
-  - with `allow_tf32=False`, performance of `triton` degrades significantly to `~.50x` of `torch`.
-
-See this [`torch note`](https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere) for more details on this feature.
-
-**Note**: This might be less of a concern given this incoming triton [PR](https://github.com/openai/triton/pull/3234), which implements a fast `TF32` trick that improves both performance and accuracy.
-
-#### Repro
-
-_Accuracy_
-
-- Test accuracy of `torch` vs `hybrid` for `M=4096`, `N=4096`, `rank=128`, and `tf32` switched on:
-
-  ```python
-  pytest test/kernel/test_fused_kernels.py
-  ```
-
-_Benchmark_
-
-- Benchmark across all kernels without `tf32`:
-
-  ```python
-  python benchmarks/bench_galore_fused_kernels.py
-  ```
-
-For additional benchmarking options:
-
-```python
- python benchmarks/bench_galore_fused_kernels.py --help
-```
-
-#### Test Env
-
-- GPU Device Props:
-  - Name: `NVIDIA RTX A6000`
-  - CC: `86`
-  - Total_memory: `48676MB`
-  - SM count: `84`
-- Torch: `2.2.2`
-- Triton: `2.2.0`
diff --git a/torchao/prototype/galore/docs/galore_adam8bit.md b/torchao/prototype/galore/docs/galore_adam8bit.md
deleted file mode 100644
index ddb45c29b8..0000000000
--- a/torchao/prototype/galore/docs/galore_adam8bit.md
+++ /dev/null
@@ -1,35 +0,0 @@
-## GaLore AdamW8bit Optimizer
-
-### Overview
-
-The `GaLore` AdamW8bit optimizer utilizes the `bitsandbytes` `AdamW8bit` optimizer to additionally quantize the optimizer states.
-
-In addition to the ops introduced by `GaLore` on top of the standard `Adam` update step (see `galore_adam.md` for details), additional dequantize / quantize steps are needed:
-
-- one to dequantize the quantized states for the state update
-- after the states are updated, they need to be re-quantized and the `quant_state` updated
-- For `bitsandbytes` `AdamW8bit`, the `quant_state` consists of group-wise (`blocksize`) scaling factors.
-
-The `bitsandbytes` 8bit optimizer is implemented in CUDA, with handcrafted logic for implementing each of these steps.
-
-> The motivation for re-implementing this optimizer purely in `triton` / `torch` is to enable exploration of various fusion / optimization strategies that would be difficult with the current CUDA impl.
-
-#### Quantization Algorithm
-
-1. Weights are quantized in contiguous `blocksize` segments
-2. Given tensor `M x N`, reshape to `-1 x blocksize`
-3. Find the `absmax` of each block (each row of the reshaped tensor) and normalize the tensor by dividing by `absmax`
-4. Reshape normalized tensor back to original shape
-5. `bitsandbytes` then uses an `8-bit` [quantization code](https://github.com/TimDettmers/bitsandbytes/blob/76885a41df9e6c94b3f80b1c37374c8441b6933e/bitsandbytes/optim/optimizer.py#L146-L151), which can either be signed or unsigned -- signed for tracking `mean`, unsigned for tracking `var`.
-6. The normalized tensor is then assigned to the code it is closest to:
-   - E.g., given normalized value `.0412` and buckets `.0402` and `.0416`, it will be assigned the latter code.
-7. **IMPORTANT**: This gives rise to a small number of edge-case errors when trying to reproduce `bitsandbytes` quantization
-   - Specifically, if a normalized value falls directly between two codes there is a degree of indeterminism.
-   - E.g., in the previous example, if the normalized value is `.0409`, it would be equidistant to the codes `.0402` and `.0416`.
-   - See the assertions in the `test_galore_quant.py` unittest that checks that these are the only discrepancies arising from the triton implementation (run with `pytest -sv -k` flags to see the output from this test).
- -### bitsandbytes CUDA Source - -- Adam[W]8bit [update step](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L1770) -- Adam blockwise [quantization](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L413) after update -- [Blockwise](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L726) [Quantization](https://github.com/TimDettmers/bitsandbytes/blob/fd9d072e02b74348004f197e686e168448883a9e/csrc/kernels.cu#L339) kernel diff --git a/torchao/prototype/galore/kernels/adam_downproj_fused.py b/torchao/prototype/galore/kernels/adam_downproj_fused.py deleted file mode 100644 index c45fc5d238..0000000000 --- a/torchao/prototype/galore/kernels/adam_downproj_fused.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -import logging - -import torch -import triton -import triton.language as tl - -from torchao.prototype.common.triton.matmul_perf_model import ( - early_config_prune, - estimate_matmul_time, -) - -from .adam_step import BETA1, BETA2, EPS -from .custom_autotune import Config, autotune -from .matmul import ( - TRITON_ACC_TYPES, - get_higher_dtype, - get_mm_heuristics, - init_to_zero, - to_tl_type, -) -from .matmul import get_autotuner as default_mm_autotuner - -logger = logging.getLogger(__name__) - -AUTOTUNER_TOP_K = 50 - - -def set_tuner_top_k(k): - global AUTOTUNER_TOP_K - AUTOTUNER_TOP_K = k - - -@triton.jit -def _fused_adam_mm_kernel( - # matmul args - A, - B, - C, - M, - N, - K, # - stride_am, - stride_ak, # - stride_bk, - stride_bn, # - stride_cm, - stride_cn, # - # adam epilogue, - exp_avg_ptr, # these will be updated inplace - exp_avg2_ptr, - store, - # grad_ptr, # low rank grad output -- not needed, C is the output - # meta params - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, # - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - GROUP_M: tl.constexpr, - # Adam-specific params - BETA1: tl.constexpr = BETA1, - BETA2: tl.constexpr = BETA2, - EPS: tl.constexpr = EPS, - # matmul kernel settings - acc_dtype: tl.constexpr = tl.float32, # - allow_tf32: tl.constexpr = False, # higher precision for this phase - fp8_fast_accum: tl.constexpr = False, # - AB_DTYPE: tl.constexpr = None, # -): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - k_remaining = K - k * (BLOCK_K * 
SPLIT_K) - _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) - if AB_DTYPE is not None: - a = a.to(AB_DTYPE) - b = b.to(AB_DTYPE) - if fp8_fast_accum: - acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) - else: - acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - # acc = acc.to(C.dtype.element_ty) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - epilogue_offsets = rm[:, None] * stride_cm + rn[None, :] * stride_cn - mask = (rm < M)[:, None] & (rn < N)[None, :] - - # Load adam state - exp_avg = tl.load(exp_avg_ptr + epilogue_offsets, mask=mask) - exp_avg2 = tl.load(exp_avg2_ptr + epilogue_offsets, mask=mask) - - # Perform update - exp_avg = BETA1 * exp_avg.to(acc.dtype) + (1.0 - BETA1) * acc - exp_avg2 = BETA2 * exp_avg2.to(acc.dtype) + (1.0 - BETA2) * (acc * acc) - denom = tl.sqrt(exp_avg2) + EPS - norm_grad = exp_avg / denom - # Convert to output type - norm_grad = norm_grad.to(C.dtype.element_ty) - - # acc = acc.to(C.dtype.element_ty) - C = C + epilogue_offsets - - # handles write-back with reduction-splitting - if SPLIT_K == 1: - tl.store(C, norm_grad, mask=mask) - else: - tl.atomic_add(C, norm_grad, mask=mask) - - if store: - tl.store( - exp_avg_ptr + epilogue_offsets, - exp_avg, - mask=mask, - ) - tl.store( - exp_avg2_ptr + epilogue_offsets, - exp_avg2, - mask=mask, - ) - - -def _get_configs_splitk_all(): - """ - Configs specific to split-k matmuls - Not used currently - """ - configs = [] - for num_stages in [2, 3, 4, 5]: - for block_m in [16, 32, 64, 128]: - for block_k in [16, 32, 64, 128, 256]: - for block_n in [16, 32, 64, 128]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k - for split_k in [2, 4, 8]: - configs.append( - Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": split_k, - }, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ) - ) - return configs - - -def _get_configs_splitk_small(): - """Configs for split-k, smaller version than above - Not used currently - """ - configs = [] - for num_stages in [2, 3, 4]: - for block_m in [64, 128]: - for block_k in [16, 32, 64]: - for block_n in [64, 128]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k - for split_k in [2, 4, 8]: - configs.append( - Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": split_k, - }, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ) - ) - return configs - - -def _splitk_autotuner( - configs=_get_configs_splitk_small(), - key=["M", "N", "K"], - early_config_prune=early_config_prune, - perf_model=estimate_matmul_time, - top_k=AUTOTUNER_TOP_K, -): - """Autotuner for splitk matmuls - Not used currently - """ - autotuner = autotune( - configs=configs, - key=key, - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": perf_model, - "top_k": top_k, - }, - ) - - return autotuner - 
- -def _get_kernel( - tuner_fn=default_mm_autotuner, heuristics_fn=get_mm_heuristics, topk=AUTOTUNER_TOP_K -): - tuner = tuner_fn() - tuner.topk = topk - heuristics = heuristics_fn() - return tuner(heuristics(_fused_adam_mm_kernel)) - - -DEFAULT_KERNEL = _get_kernel() - - -def fused_adam_mm_launcher( - a, - b, - *, - exp_avg, - exp_avg2, - store=True, - BETA1=BETA1, - BETA2=BETA2, - EPS=EPS, - allow_tf32=False, - fp8_fast_accum=False, - acc_dtype=None, - output_dtype=None, - kernel=None, -): - device = a.device - # handle non-contiguous inputs if necessary - # a = grad - # b = galore_proj.ortho_matrix.t() - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - - # common type between a and b - ab_dtype = get_higher_dtype(a.dtype, b.dtype) - - # allocates output - if output_dtype is None: - output_dtype = ab_dtype - - c = torch.empty((M, N), device=device, dtype=output_dtype) - - if acc_dtype is None: - acc_dtype = [ab_dtype][0] - else: - assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" - assert acc_dtype in TRITON_ACC_TYPES[a.dtype], ( - "acc_dtype not compatible with the type of a" - ) - assert acc_dtype in TRITON_ACC_TYPES[b.dtype], ( - "acc_dtype not compatible with the type of b" - ) - - acc_dtype = to_tl_type(acc_dtype) - ab_dtype = to_tl_type(ab_dtype) - output_dtype = to_tl_type(output_dtype) - - # Tensor cores support input with mixed float8 types. - if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ - tl.float8e4nv, - tl.float8e5, - ]: - ab_dtype = None - - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - META["SPLIT_K"], - ) - - if kernel is None: - kernel = DEFAULT_KERNEL - kernel[grid]( - a, - b, - c, - M, - N, - K, # - a.stride(0), - a.stride(1), # - b.stride(0), - b.stride(1), # - c.stride(0), - c.stride(1), # - exp_avg, - exp_avg2, - store=store, - BETA1=BETA1, # , # - BETA2=BETA2, # , # - EPS=EPS, # - acc_dtype=acc_dtype, # - allow_tf32=allow_tf32, # - fp8_fast_accum=fp8_fast_accum, # - GROUP_M=8, - AB_DTYPE=ab_dtype, - ) - return exp_avg, exp_avg2, c # c -> normalized low rank grad diff --git a/torchao/prototype/galore/kernels/matmul.py b/torchao/prototype/galore/kernels/matmul.py deleted file mode 100644 index 0a7c830f02..0000000000 --- a/torchao/prototype/galore/kernels/matmul.py +++ /dev/null @@ -1,417 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -import torch -import triton -import triton.language as tl - -from torchao.prototype.common.triton.matmul_perf_model import ( - early_config_prune, - estimate_matmul_time, -) - -from .custom_autotune import Config, autotune, heuristics - -# Allowed types for acc_type given the types of a and b. 
-TRITON_ACC_TYPES = { - torch.float16: (torch.float32, torch.float16), - torch.bfloat16: (torch.float32, torch.bfloat16), - torch.float32: (torch.float32,), - torch.int8: (torch.int32,), -} - -AUTOTUNER_TOP_K = 50 -_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] - - -def upcast_if_fp8(a): - if "fp8" in str(a): - return torch.float16 - return a - - -def get_higher_dtype(a, b): - a = upcast_if_fp8(a) - b = upcast_if_fp8(b) - if a is b: - return a - - assert a in _ordered_datatypes - assert b in _ordered_datatypes - - for d in _ordered_datatypes: - if a is d: - return b - if b is d: - return a - - -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() - - -def set_tuner_top_k(k): - global AUTOTUNER_TOP_K - AUTOTUNER_TOP_K = k - - -def to_tl_type(ty): - return getattr(tl, str(ty).split(".")[-1]) - - -def get_configs_io_bound(): - configs = [] - for num_stages in [2, 3, 4, 5, 6]: - for block_m in [16, 32]: - for block_k in [32, 64]: - for block_n in [32, 64, 128, 256]: - num_warps = 2 if block_n <= 64 else 4 - configs.append( - Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": 1, - }, - num_stages=num_stages, - num_warps=num_warps, - ) - ) - # split_k - for split_k in [2, 4, 8, 16]: - configs.append( - Config( - { - "BLOCK_M": block_m, - "BLOCK_N": block_n, - "BLOCK_K": block_k, - "SPLIT_K": split_k, - }, - num_stages=num_stages, - num_warps=num_warps, - pre_hook=init_to_zero("C"), - ) - ) - return configs - - -def get_configs_compute_bound(): - configs = [ - # basic configs for compute-bound matmuls - Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, - num_stages=5, - num_warps=2, - ), - # good for int8 - Config( - {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=3, - num_warps=8, - ), - Config( - {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=4, - num_warps=4, - ), - Config( - {"BLOCK_M": 64, "BLOCK_N": 32, 
"BLOCK_K": 64, "SPLIT_K": 1}, - num_stages=5, - num_warps=2, - ), - ] - return configs - - -def get_autotuner( - configs=get_configs_compute_bound() + get_configs_io_bound(), - key=["M", "N", "K"], - early_config_prune=early_config_prune, - perf_model=estimate_matmul_time, - top_k=AUTOTUNER_TOP_K, -): - autotuner = autotune( - configs=configs, - key=key, - prune_configs_by={ - "early_config_prune": early_config_prune, - "perf_model": perf_model, - "top_k": top_k, - }, - ) - - return autotuner - - -def get_mm_heuristics(): - return heuristics( - { - "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - } - ) - - -@triton.jit -def _matmul_kernel( - A, - B, - C, - M, - N, - K, # - stride_am, - stride_ak, # - stride_bk, - stride_bn, # - stride_cm, - stride_cn, # - # meta params - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, # - SPLIT_K: tl.constexpr, - EVEN_K: tl.constexpr, - GROUP_M: tl.constexpr, - # epilogue - epilogue_alpha=None, - epilogue_beta=None, - epilogue_source=None, # Corresponds to C in GEMM convention of D = AB + C - # matmul kernel settings - acc_dtype: tl.constexpr = tl.float32, # - allow_tf32: tl.constexpr = True, # - fp8_fast_accum: tl.constexpr = True, # - AB_DTYPE: tl.constexpr = None, # - EPILOGUE: tl.constexpr = False, -): - # matrix multiplication - pid = tl.program_id(0) - pid_z = tl.program_id(1) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) - # pointers - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) - for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - k_remaining = K - k * (BLOCK_K * SPLIT_K) - _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) - a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) - b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) - if AB_DTYPE is not None: - a = a.to(AB_DTYPE) - b = b.to(AB_DTYPE) - if fp8_fast_accum: - acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) - else: - acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) - A += BLOCK_K * SPLIT_K * stride_ak - B += BLOCK_K * SPLIT_K * stride_bk - # acc = acc.to(C.dtype.element_ty) - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - if EPILOGUE: - if epilogue_alpha is not None: - acc = epilogue_alpha.to(acc_dtype) * acc - if epilogue_source is not None: - epilogue_src = tl.load( - epilogue_source + rm[:, None] * stride_cm + rn[None, :] * stride_cn - ) - if epilogue_beta is not None: - epilogue_src = epilogue_src.to(acc_dtype) * epilogue_beta.to(acc_dtype) - acc = acc + epilogue_src - - acc = acc.to(C.dtype.element_ty) - C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) - mask = (rm < M)[:, None] & (rn < N)[None, :] - # handles write-back with 
reduction-splitting - if SPLIT_K == 1: - tl.store(C, acc, mask=mask) - else: - tl.atomic_add(C, acc, mask=mask) - - -_autotuner = get_autotuner() -_heuristics = get_mm_heuristics() -matmul = _autotuner(_heuristics(_matmul_kernel)) - - -def triton_mm_launcher( - a, - b, - epilogue_alpha=None, - epilogue_beta=None, - epilogue_source=None, - allow_tf32=True, - fp8_fast_accum=True, - acc_dtype=None, - output_dtype=None, - kernel=matmul, -): - device = a.device - # handle non-contiguous inputs if necessary - # a = grad - # b = galore_proj.ortho_matrix.t() - if a.stride(0) > 1 and a.stride(1) > 1: - a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: - b = b.contiguous() - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - - # common type between a and b - ab_dtype = get_higher_dtype(a.dtype, b.dtype) - - # allocates output - if output_dtype is None: - output_dtype = ab_dtype - - c = torch.empty((M, N), device=device, dtype=output_dtype) - - if acc_dtype is None: - acc_dtype = [ab_dtype][0] - else: - assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" - assert acc_dtype in TRITON_ACC_TYPES[a.dtype], ( - "acc_dtype not compatible with the type of a" - ) - assert acc_dtype in TRITON_ACC_TYPES[b.dtype], ( - "acc_dtype not compatible with the type of b" - ) - - acc_dtype = to_tl_type(acc_dtype) - ab_dtype = to_tl_type(ab_dtype) - output_dtype = to_tl_type(output_dtype) - - # Tensor cores support input with mixed float8 types. - if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ - tl.float8e4nv, - tl.float8e5, - ]: - ab_dtype = None - # launch kernel - # print( - # f"{__file__} triton matmul args: (AB dtype {ab_dtype}) (C dtype {c.dtype}) (allow_tf32 {allow_tf32}) (fp8_fast_accum {fp8_fast_accum})" - # ) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), - META["SPLIT_K"], - ) - - matmul[grid]( - a, - b, - c, - M, - N, - K, # - a.stride(0), - a.stride(1), # - b.stride(0), - b.stride(1), # - c.stride(0), - c.stride(1), # - epilogue_alpha=epilogue_alpha, # - epilogue_beta=epilogue_beta, # - epilogue_source=epilogue_source, # - acc_dtype=acc_dtype, # - allow_tf32=allow_tf32, # - fp8_fast_accum=fp8_fast_accum, # - GROUP_M=8, - AB_DTYPE=ab_dtype, - EPILOGUE=any([epilogue_alpha, epilogue_beta, epilogue_source]), - ) - return c diff --git a/torchao/prototype/galore/optim/galore_torch.py b/torchao/prototype/galore/optim/galore_torch.py deleted file mode 100644 index 876c40292d..0000000000 --- a/torchao/prototype/galore/optim/galore_torch.py +++ /dev/null @@ -1,401 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-"""Reference implementations -See https://github.com/jiaweizzhao/GaLore/tree/master/galore_torch -""" - -# copy dependencies from transformers/optimization.py -import math -import warnings -from typing import Callable, Iterable, Tuple - -import torch -from bitsandbytes.optim.optimizer import Optimizer2State -from torch import nn -from torch.optim import Optimizer - - -class GaLoreProjector: - def __init__( - self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std" - ): - self.rank = rank - self.verbose = verbose - self.update_proj_gap = update_proj_gap - self.scale = scale - self.ortho_matrix = None - self.proj_type = proj_type - - def project(self, full_rank_grad, iter): - if self.proj_type == "std": - if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="right" - ) - low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) - else: - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="left" - ) - low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) - elif self.proj_type == "reverse_std": - if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="left" - ) - low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) - else: - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="right" - ) - low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) - elif self.proj_type == "right": - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="right" - ) - low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) - elif self.proj_type == "left": - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="left" - ) - low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) - elif self.proj_type == "full": - if self.ortho_matrix is None or iter % self.update_proj_gap == 0: - self.ortho_matrix = self.get_orthogonal_matrix( - full_rank_grad, self.rank, type="full" - ) - low_rank_grad = ( - torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) - @ self.ortho_matrix[1].t() - ) - - return low_rank_grad - - def project_back(self, low_rank_grad): - if self.proj_type == "std": - if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: - full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) - else: - full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) - elif self.proj_type == "reverse_std": - if ( - low_rank_grad.shape[0] <= low_rank_grad.shape[1] - ): # note this is different from std - full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) - else: - full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) - elif self.proj_type == "right": - full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) - elif self.proj_type == "left": - full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) - elif self.proj_type == "full": - full_rank_grad = ( - torch.matmul(self.ortho_matrix[0], low_rank_grad) @ 
self.ortho_matrix[1] - ) - - return full_rank_grad * self.scale - - # svd decomposition - def get_orthogonal_matrix(self, weights, rank, type): - module_params = weights - - if module_params.data.dtype != torch.float: - float_data = False - original_type = module_params.data.dtype - original_device = module_params.data.device - matrix = module_params.data.float() - else: - float_data = True - matrix = module_params.data - - U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) - - # make the smaller matrix always to be orthogonal matrix - if type == "right": - # A = U[:, :rank] @ torch.diag(s[:rank]) - B = Vh[:rank, :] - - if not float_data: - B = B.to(original_device).type(original_type) - return B - elif type == "left": - A = U[:, :rank] - # B = torch.diag(s[:rank]) @ Vh[:rank, :] - if not float_data: - A = A.to(original_device).type(original_type) - return A - elif type == "full": - A = U[:, :rank] - B = Vh[:rank, :] - if not float_data: - A = A.to(original_device).type(original_type) - B = B.to(original_device).type(original_type) - return [A, B] - else: - raise ValueError("type should be left, right or full") - - -class AdamW(Optimizer): - """ - Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay - Regularization](https://arxiv.org/abs/1711.05101). - - Parameters: - params (`Iterable[nn.parameter.Parameter]`): - Iterable of parameters to optimize or dictionaries defining parameter groups. - lr (`float`, *optional*, defaults to 0.001): - The learning rate to use. - betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): - Adam's betas parameters (b1, b2). - eps (`float`, *optional*, defaults to 1e-06): - Adam's epsilon for numerical stability. - weight_decay (`float`, *optional*, defaults to 0.0): - Decoupled weight decay to apply. - correct_bias (`bool`, *optional*, defaults to `True`): - Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). - no_deprecation_warning (`bool`, *optional*, defaults to `False`): - A flag used to disable the deprecation warning (set to `True` to disable the warning). - """ - - def __init__( - self, - params: Iterable[nn.parameter.Parameter], - lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), - eps: float = 1e-6, - weight_decay: float = 0.0, - correct_bias: bool = True, - no_deprecation_warning: bool = False, - ): - if not no_deprecation_warning: - warnings.warn( - "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch" - " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this" - " warning", - FutureWarning, - ) - if lr < 0.0: - raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") - if not 0.0 <= betas[0] < 1.0: - raise ValueError( - f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)" - ) - if not 0.0 <= betas[1] < 1.0: - raise ValueError( - f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)" - ) - if not 0.0 <= eps: - raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") - defaults = { - "lr": lr, - "betas": betas, - "eps": eps, - "weight_decay": weight_decay, - "correct_bias": correct_bias, - } - super().__init__(params, defaults) - - @torch.no_grad() - def step(self, closure: Callable = None): - """ - Performs a single optimization step. - - Arguments: - closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. 
- """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "Adam does not support sparse gradients, please consider SparseAdam instead" - ) - - state = self.state[p] - - if "step" not in state: - state["step"] = 0 - - # GaLore Projection - if "rank" in group: - if "projector" not in state: - state["projector"] = GaLoreProjector( - group["rank"], - update_proj_gap=group["update_proj_gap"], - scale=group["scale"], - proj_type=group["proj_type"], - ) - - grad = state["projector"].project(grad, state["step"]) - - # State initialization - if "exp_avg" not in state: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(grad) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(grad) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - beta1, beta2 = group["betas"] - - state["step"] += 1 - - # Decay the first and second moment running average coefficient - # In-place operations to update the averages at the same time - exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) - denom = exp_avg_sq.sqrt().add_(group["eps"]) - - step_size = group["lr"] - if group["correct_bias"]: # No bias correction for Bert - bias_correction1 = 1.0 - beta1 ** state["step"] - bias_correction2 = 1.0 - beta2 ** state["step"] - step_size = ( - step_size * math.sqrt(bias_correction2) / bias_correction1 - ) - - # compute norm gradient - norm_grad = exp_avg / denom - - # GaLore Projection Back - if "rank" in group: - norm_grad = state["projector"].project_back(norm_grad) - - p.add_(norm_grad, alpha=-step_size) - - # Just adding the square of the weights to the loss function is *not* - # the correct way of using L2 regularization/weight decay with Adam, - # since that will interact with the m and v parameters in strange ways. - # - # Instead we want to decay the weights in a manner that doesn't interact - # with the m/v parameters. This is equivalent to adding the square - # of the weights to the loss with plain (non-momentum) SGD. - # Add weight decay at the end (fixed version) - if group["weight_decay"] > 0.0: - p.add_(p, alpha=(-group["lr"] * group["weight_decay"])) - - return loss - - -class AdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - is_paged=False, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - is_paged=is_paged, - ) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - if not self.initialized: - self.check_overrides() - self.to_gpu() # needed for fairseq pure fp16 training - self.initialized = True - - # if self.is_paged: self.page_mng.prefetch_all() - for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group["params"]): - if p.grad is None: - continue - state = self.state[p] - - if "step" not in state: - state["step"] = 0 - - # GaLore Projection - if "rank" in group: - if "projector" not in state: - state["projector"] = GaLoreProjector( - group["rank"], - update_proj_gap=group["update_proj_gap"], - scale=group["scale"], - proj_type=group["proj_type"], - ) - - if "weight_decay" in group and group["weight_decay"] > 0: - # ensure that the weight decay is not applied to the norm grad - group["weight_decay_saved"] = group["weight_decay"] - group["weight_decay"] = 0 - - grad = state["projector"].project(p.grad, state["step"]) - - # suboptimal implementation - p.saved_data = p.data.clone() - p.data = grad.clone().to(p.data.dtype).to(p.data.device) - p.data.zero_() - p.grad = grad - - if "state1" not in state: - self.init_state(group, p, gindex, pindex) - - self.prefetch_state(p) - self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() - - # GaLore Projection Back - if "rank" in group: - p.data = p.saved_data.add_(state["projector"].project_back(p.data)) - - # apply weight decay - if "weight_decay_saved" in group: - p.data.add_( - p.data, alpha=-group["lr"] * group["weight_decay_saved"] - ) - group["weight_decay"] = group["weight_decay_saved"] - del group["weight_decay_saved"] - - if self.is_paged: - # all paged operation are asynchronous, we need - # to sync to make sure all tensors are in the right state - torch.cuda.synchronize() - - return loss diff --git a/torchao/prototype/galore/utils.py b/torchao/prototype/galore/utils.py deleted file mode 100644 index 6e9db05d30..0000000000 --- a/torchao/prototype/galore/utils.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import torch - - -def get_orthogonal_matrix(weights, rank, type): - module_params = weights - - if module_params.data.dtype != torch.float: - float_data = False - original_type = module_params.data.dtype - original_device = module_params.data.device - matrix = module_params.data.float() - else: - float_data = True - matrix = module_params.data - - U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) - - # make the smaller matrix always to be orthogonal matrix - if type == "right": - # A = U[:, :rank] @ torch.diag(s[:rank]) - B = Vh[:rank, :] - - if not float_data: - B = B.to(original_device).type(original_type) - return B - elif type == "left": - A = U[:, :rank] - # B = torch.diag(s[:rank]) @ Vh[:rank, :] - if not float_data: - A = A.to(original_device).type(original_type) - return A - elif type == "full": - A = U[:, :rank] - B = Vh[:rank, :] - if not float_data: - A = A.to(original_device).type(original_type) - B = B.to(original_device).type(original_type) - return [A, B] - else: - raise ValueError("type should be left, right or full") - - -class TestGaLoreProjector: - def __init__( - self, - rank=128, - scale=1.0, - proj_type="std", - ): - self.rank = rank - self.scale = scale - - if proj_type != "std": - raise ("Only std projection is supported") - - self.proj_type = proj_type - - self.ortho_matrix = None - - def update_orthogonal_matrix(self, full_rank_grad): - if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: - self.ortho_matrix = get_orthogonal_matrix( - full_rank_grad, self.rank, type="right" - ) - else: - self.ortho_matrix = get_orthogonal_matrix( - full_rank_grad, self.rank, type="left" - ) - - def project(self, full_rank_grad): - if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: - low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) - else: - low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) - - return low_rank_grad - - def project_back(self, low_rank_grad): - if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: - full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) - else: - full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) - - return full_rank_grad * self.scale - - -def make_copy(*args): - return [t.detach().clone() for t in args] - - -# def adam_step( -# exp_avg, -# exp_avg2, -# grad, -# galore_proj, -# params, -# step_size=1e-4, -# beta1=BETA1, -# beta2=BETA2, -# eps=EPS, -# ): -# grad = galore_proj.project(grad) -# exp_avg = beta1 * exp_avg + (1 - beta1) * grad -# exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * torch.square(grad) -# denom = exp_avg2.sqrt() + eps -# norm_grad = exp_avg / denom -# norm_grad = galore_proj.project_back(norm_grad) -# # params = params - step_size * norm_grad -# return exp_avg, exp_avg2, denom, norm_grad
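The commented-out adam_step above sketches the eager reference for one GaLore-Adam update built on the projector defined in this file. A minimal runnable version of that sketch, with illustrative beta1/beta2/eps defaults standing in for the undefined BETA1/BETA2/EPS constants, could look like:

import torch

def adam_step_reference(exp_avg, exp_avg2, grad, galore_proj, params,
                        step_size=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    # default beta/eps values are illustrative assumptions
    # project the full-rank gradient down to the low-rank subspace
    grad = galore_proj.project(grad)
    # Adam moment updates on the low-rank gradient
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * torch.square(grad)
    denom = exp_avg2.sqrt() + eps
    norm_grad = exp_avg / denom
    # project the normalized gradient back to full rank and apply the step
    norm_grad = galore_proj.project_back(norm_grad)
    params = params - step_size * norm_grad
    return exp_avg, exp_avg2, denom, norm_grad, params

Here galore_proj is expected to be a TestGaLoreProjector whose update_orthogonal_matrix has already been called on a full-rank gradient of the same shape as grad.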