diff --git a/README.md b/README.md
index dde75e208..567851124 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ Once you have confirmed access, you can run the following command to download th
 ```bash
 # Get your HF token from https://huggingface.co/settings/tokens
 
-# llama3 tokenizer.model
+# llama3 or 3.1 tokenizer.model
 python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...
 
 # llama2 tokenizer.model
diff --git a/torchtitan/datasets/download_tokenizer.py b/torchtitan/datasets/download_tokenizer.py
index 44ef5f59e..a419d7090 100644
--- a/torchtitan/datasets/download_tokenizer.py
+++ b/torchtitan/datasets/download_tokenizer.py
@@ -20,8 +20,8 @@ def hf_download(
 
     try:
         hf_hub_download(
-            repo_id,
-            tokenizer_path,
+            repo_id=repo_id,
+            filename=tokenizer_path,
             local_dir=local_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
diff --git a/torchtitan/models/llama/__init__.py b/torchtitan/models/llama/__init__.py
index 3cdfe0f99..887a96cdc 100644
--- a/torchtitan/models/llama/__init__.py
+++ b/torchtitan/models/llama/__init__.py
@@ -48,4 +48,13 @@
         multiple_of=4096,
         rope_theta=500000,
     ),
+    "405B": ModelArgs(
+        dim=16384,
+        n_layers=126,
+        n_heads=128,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.2,
+        multiple_of=4096,
+        rope_theta=500000,
+    ),
 }
diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml
new file mode 100644
index 000000000..fb250642e
--- /dev/null
+++ b/train_configs/llama3_405b.toml
@@ -0,0 +1,53 @@
+# torchtitan Config.toml
+# NOTE: this toml config is a preset for 128 H100 GPUs.
+
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 405B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "405B"
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
+tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
+
+[optimizer]
+name = "AdamW"
+lr = 0.8e-4
+
+[training]
+batch_size = 2
+seq_len = 8192
+warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
+max_norm = 1.0 # grad norm clipping
+steps = 3000
+data_parallel_degree = -1
+tensor_parallel_degree = 8 # 8-way TP
+enable_float8_linear = false
+compile = false
+dataset = "c4"
+
+[experimental]
+pipeline_parallel_degree = 1
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval_type = "steps"
+interval = 500
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = 'full' # ['none', 'selective', 'full']
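
For reference, a minimal launch sketch for trying the new preset. It assumes torchtitan's existing `run_llama_train.sh` launcher and its `CONFIG_FILE`/`NGPU` environment variables; the toml above is tuned for 128 H100 GPUs, so a full run would go through the repo's multi-node launcher rather than the single-node command shown here.

```bash
# Sketch: smoke-test the new 405B preset on a single 8-GPU node.
# Assumes run_llama_train.sh reads CONFIG_FILE and NGPU (as in the existing repo scripts);
# the preset targets 128 H100s (data_parallel_degree = -1 with 8-way TP), so a real
# training run would use the multi-node launcher instead of this one-node invocation.
CONFIG_FILE="./train_configs/llama3_405b.toml" NGPU=8 ./run_llama_train.sh
```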