23 changes: 17 additions & 6 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -34,6 +34,7 @@ def __init__(
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
pp_batch_size: int = 8,
save_interval: int = 100,
save_dir: str = "./model",
):
@@ -54,7 +55,13 @@ def __init__(

self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

# To support pipeline parallelism,
# we use pp_batch_size as the batch fed to the pipeline schedule,
# so the pp microbatch size = pp_batch_size // pp_size.
self.pp_microbatch_size = pp_batch_size // self.plugin_config.get("pp_size", 1)
self.pp_num_microbatches = pp_batch_size // self.pp_microbatch_size
# assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

self.device = get_current_device()

@@ -66,13 +73,18 @@ def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

plugin_config = dict(
tp_size=1,
pp_size=1,
tp_size=self.plugin_config.get("tp_size", 1),
pp_size=self.plugin_config.get("pp_size", 1),
# microbatch_size=self.pp_microbatch_size,
num_microbatches=self.pp_num_microbatches,
precision="bf16",
zero_stage=1,
enable_flash_attention=True,
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size

# if plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
# # plugin_config["microbatch_size"] = self.microbatch_size
# plugin_config["num_microbatches"] = plugin_config.get("pp_size", 1)
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
@@ -99,7 +111,6 @@ def loop(self) -> None:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers

for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
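A note on the batching arithmetic introduced above: pp_batch_size is carved into equal pipeline microbatches, so pp_microbatch_size and pp_num_microbatches are tied together. A minimal standalone sketch of that relationship (the helper name below is ours, not the PR's):

def pp_schedule_sizes(pp_batch_size: int, pp_size: int) -> tuple:
    """Return (pp_microbatch_size, pp_num_microbatches) for an evenly split pipeline batch."""
    assert pp_batch_size % pp_size == 0, "pp_batch_size must be divisible by pp_size"
    pp_microbatch_size = pp_batch_size // pp_size
    pp_num_microbatches = pp_batch_size // pp_microbatch_size  # equals pp_size when the split is even
    return pp_microbatch_size, pp_num_microbatches

print(pp_schedule_sizes(8, 2))  # (4, 2): HybridParallelPlugin then gets num_microbatches=2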
126 changes: 102 additions & 24 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -3,15 +3,18 @@

import ray
import torch
import torch.distributed as dist
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.distributed.utils import calc_action_log_probs, filter_microbatch_dicts, split_into_microbatches
from coati.trainer.utils import all_reduce_mean
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam


@@ -31,6 +34,7 @@ def __init__(
model_config,
plugin_config,
microbatch_size=1,
pp_batch_size=8,
num_generations=4,
use_wandb=True,
):
@@ -47,6 +51,7 @@ def __init__(
model_config,
plugin_config,
microbatch_size,
pp_batch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -86,10 +91,13 @@ def __init__(
if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

self.coordinator = None

def setup(self):
super().setup()
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.coordinator = DistCoordinator()

def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""
@@ -106,31 +114,103 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""

# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)

need_update = (step_idx + 1) % self.num_microbatches == 0

ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
ctx = nullcontext()
# ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
# print(f"Before split Rank {dist.get_rank()}] \
# input_ids {data['input_ids'].shape} \
# attention_mask {data['attention_mask'].shape} \
# action_mask {data['action_mask'].shape} \
# gt_answer {data['gt_answer'].shape}\ ")

data_iter = split_into_microbatches(data, self.pp_microbatch_size) # self.pp_num_microbatches

# print(f"After split Rank {dist.get_rank()}] \
# input_ids {data_iter[0]['input_ids'].shape} \
# attention_mask {data_iter[0]['attention_mask'].shape} \
# action_mask {data_iter[0]['action_mask'].shape} \
# gt_answer {data_iter[0]['gt_answer'].shape}\ ")

input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
gt_answer = data["gt_answer"]
response_idx = data["response_idx"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)

policy_model_logits = None
reference_model_logits = None
if self.booster.plugin.pp_size > 1:
# allowed_keys = ("input_ids", "attention_mask")
# data_iter = [{key: value for key, value in data.items() if key in allowed_keys}]
data_iter = filter_microbatch_dicts(data_iter)
# data_iter is already a list of microbatch dicts; each pipeline step consumes exactly one of them.
step_bar = tqdm(
range(len(data_iter)),
desc="Step",
disable=self.coordinator.rank != self.coordinator.world_size - 1,
)
# Create separate iterators for the policy model and the reference model;
# sharing a single iterator would exhaust it early and raise StopIteration.
data_iter, data_iter_infer = iter(data_iter), iter(data_iter)
for step in step_bar:
policy_model_output = self.booster.execute_pipeline(
data_iter,
self.policy_model,
criterion=lambda x, y: x.logits.mean(),
optimizer=self.optimizer,
return_loss=False,
return_outputs=True,
)

with torch.no_grad():
reference_model_output = self.booster.execute_pipeline(
data_iter_infer,
self.reference_model,
criterion=lambda x, y: x.logits.mean(),
return_loss=False,
return_outputs=True,
)

if self.booster.plugin.stage_manager.is_last_stage():
local_policy_model_logits = policy_model_output["outputs"]["logits"]
local_reference_model_logits = reference_model_output["outputs"]["logits"]
if step == 0:
policy_model_logits = local_policy_model_logits
reference_model_logits = local_reference_model_logits
else:
policy_model_logits = torch.cat((policy_model_logits, local_policy_model_logits), dim=0)
reference_model_logits = torch.cat(
(reference_model_logits, local_reference_model_logits), dim=0
)
if self.booster.plugin.stage_manager.is_last_stage():
print(
f"Rank {dist.get_rank()} step {step} policy_model_logits {policy_model_logits.shape} {policy_model_logits} reference_model_logits {reference_model_logits.shape} {reference_model_logits}"
)

else:
policy_model_logits = self.policy_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids,
attention_mask=attention_mask,
)["logits"]

action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
policy_model_logits, input_ids, num_action, self.plugin.shard_config
)

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
reference_model_logits, input_ids, num_action, self.plugin.shard_config
)

per_token_kl = (
@@ -140,13 +220,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
)
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward_group = self.reward_model(input_ids, gt_answer=gt_answer, response_idx=response_idx)

reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
reward = torch.tensor([value[0] for value in reward_group]).to(input_ids.device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(input_ids.device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(input_ids.device)

# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
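The pipeline branch above boils down to: split the rollout batch into microbatch dicts, keep only the model inputs, run one forward per microbatch, and concatenate the last stage's logits. A self-contained toy sketch of that loop (toy_forward is a stand-in we made up for booster.execute_pipeline, whose real output only exists on the last pipeline stage):

import torch

def run_microbatched(forward_fn, microbatch_dicts):
    """Run forward_fn once per microbatch and concatenate the returned logits along dim 0."""
    logits = None
    for mb in microbatch_dicts:
        out = forward_fn(mb["input_ids"], mb["attention_mask"])
        logits = out if logits is None else torch.cat((logits, out), dim=0)
    return logits

# Toy stand-in for a last-stage forward: (batch, seq_len) -> (batch, seq_len, vocab=16).
toy_forward = lambda ids, mask: torch.randn(ids.size(0), ids.size(1), 16)
batch = {"input_ids": torch.randint(0, 16, (8, 5)), "attention_mask": torch.ones(8, 5, dtype=torch.long)}
microbatches = [{k: v[i : i + 4] for k, v in batch.items()} for i in range(0, 8, 4)]
print(run_microbatched(toy_forward, microbatches).shape)  # torch.Size([8, 5, 16])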
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -36,6 +36,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_microbatch_size: int,
pp_batch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
@@ -94,6 +95,7 @@
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
pp_batch_size=pp_batch_size,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
45 changes: 45 additions & 0 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -26,6 +26,51 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
return batch


def split_into_microbatches(data_dict, microbatch_size):
"""
Split a dict of tensors into a list of microbatch dicts according to microbatch_size.
:param data_dict: dict of tensors batched along dim 0, e.g. input_ids of shape (batch_size, seq_len)
:param microbatch_size: number of samples per microbatch
:return: list of microbatch dicts
"""
batch_size = next(iter(data_dict.values())).size(0)
microbatch_dicts = []

for start_idx in range(0, batch_size, microbatch_size):
end_idx = min(start_idx + microbatch_size, batch_size)
microbatch_dict = {}
for key, tensor in data_dict.items():
if tensor.size(0) == batch_size:
microbatch_dict[key] = tensor[start_idx:end_idx]
else:
microbatch_dict[key] = tensor
microbatch_dicts.append(microbatch_dict)

return microbatch_dicts


def cyclic_iter(dataloader):
epoch = 0
while True:
for batch in dataloader:
yield batch
epoch += 1


def filter_microbatch_dicts(microbatch_dicts):
"""
Iterate over microbatch_dicts and drop every key-value pair whose key is not in ("input_ids", "attention_mask").
:param microbatch_dicts: list of microbatch dicts
:return: filtered list of microbatch dicts
"""
filtered_dicts = []
allowed_keys = ("input_ids", "attention_mask")
for microbatch_dict in microbatch_dicts:
filtered_dict = {key: value for key, value in microbatch_dict.items() if key in allowed_keys}
filtered_dicts.append(filtered_dict)
return filtered_dicts


def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress mask to save bandwidth
if "attention_mask" in batch:
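A quick usage sketch for the two helpers above (assuming this branch's coati.distributed.utils is importable; the tensor shapes are made up):

import torch
from coati.distributed.utils import filter_microbatch_dicts, split_into_microbatches

data = {
    "input_ids": torch.randint(0, 100, (8, 16)),
    "attention_mask": torch.ones(8, 16, dtype=torch.long),
    "gt_answer": torch.zeros(8, 4),  # extra key, dropped by the filter below
}
microbatches = split_into_microbatches(data, microbatch_size=4)  # two dicts of batch size 4
pipeline_inputs = filter_microbatch_dicts(microbatches)          # keep only the model inputs
print(len(microbatches), sorted(pipeline_inputs[0].keys()))      # 2 ['attention_mask', 'input_ids']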
10 changes: 7 additions & 3 deletions applications/ColossalChat/rl_example.py
@@ -14,6 +14,7 @@
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
parser.add_argument("-ppmbs", "--pp-batch-size", type=int, default=8)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
args = parser.parse_args()
@@ -31,13 +32,15 @@
if args.backend == "transformers":
inference_model_config.update(
dict(
use_flash_attention_2=True,
# use_flash_attention_2=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
use_flash_attention_2=True,
# use_flash_attention_2=True,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_cache=False,
)
@@ -89,12 +92,13 @@
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_microbatch_size=args.train_microbatch_size,
pp_batch_size=args.pp_batch_size,
dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
train_model_config=train_model_config,
plugin_config={},
plugin_config={"tp_size": 1, "pp_size": 2},
inference_backend=args.backend,
master_addr="localhost",
master_port=29504,
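On the flash-attention switch above: recent transformers releases deprecate use_flash_attention_2 in favor of the attn_implementation argument. A hedged standalone sketch of the equivalent call (the model path is only an example, and flash-attn must be installed):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",             # example path, not the one used in this PR
    attn_implementation="flash_attention_2",  # replaces the deprecated use_flash_attention_2=True
    torch_dtype=torch.bfloat16,
)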
10 changes: 6 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1416,10 +1416,12 @@ def execute_pipeline(
):
return outputs

# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()
# Synchronize when training
if torch.is_grad_enabled():
# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()

# Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
# Otherwise, synchronize data parallelism gradients of the model.
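The guard added above matters because execute_pipeline is now also called for inference-only passes (the reference model runs under torch.no_grad()), where there are no gradients to synchronize. A minimal sketch of the gating idea, not the plugin code itself:

import torch

def maybe_sync_grads(sync_fn):
    # Only synchronize gradient state when autograd is recording, i.e. during a training pass.
    if torch.is_grad_enabled():
        sync_fn()

maybe_sync_grads(lambda: print("syncing shared-param grads"))  # printed
with torch.no_grad():
    maybe_sync_grads(lambda: print("skipped"))                 # not printed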
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/loss.py
@@ -387,7 +387,7 @@ def dist_log_prob(
dtype=dtype,
)
else:
log_prob = log_softmax(logits)
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

return log_prob
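The one-line fix above pins log_softmax to the vocabulary dimension; with no explicit dim, the implicit-dim behavior is deprecated and can normalize over the wrong axis for 3D logits. A toy sketch of the non-distributed branch (shapes are illustrative only, assuming plain torch.nn.functional.log_softmax):

import torch
from torch.nn.functional import log_softmax

logits = torch.randn(2, 5, 11)            # (batch, seq_len, vocab)
labels = torch.randint(0, 11, (2, 5))
log_prob = log_softmax(logits, dim=-1)    # normalize over the vocab dimension explicitly
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
print(log_prob.shape)                     # torch.Size([2, 5, 1])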