# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(
    logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Subtract the maximum value.
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
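The max subtraction is purely for numerical stability: softmax is invariant to shifting all logits by a constant, but exponentiating large raw logits would overflow. A minimal single-process sketch of that invariance (plain torch, no parallelism assumed):

```python
import torch

# Softmax is shift-invariant: subtracting the (global) max changes nothing
# mathematically, but it prevents overflow in exp().
logits = torch.tensor([1000.0, 1001.0, 999.0])

naive = torch.exp(logits) / torch.exp(logits).sum()   # exp(1000) overflows -> nan
stable = torch.exp(logits - logits.max())
stable = stable / stable.sum()

print(naive)   # tensor([nan, nan, nan])
print(stable)  # tensor([0.2447, 0.6652, 0.0900])
```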
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
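Each rank only holds a contiguous slice of the vocabulary, so targets that fall outside this rank's [vocab_start_index, vocab_end_index) range are flagged and then clamped to 0 (a valid local index), which keeps the gather below in bounds. A small single-process sketch, assuming an even vocabulary split (in the actual code the bounds come from Megatron's VocabUtility earlier in the forward):

```python
import torch

vocab_size, world_size, rank = 8, 2, 1                        # hypothetical sharding
partition_vocab_size = vocab_size // world_size               # 4
vocab_start_index = rank * partition_vocab_size               # 4
vocab_end_index = vocab_start_index + partition_vocab_size    # 8

target = torch.tensor([0, 5, 3, 7])

target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0

print(target_mask)    # tensor([ True, False,  True, False])
print(masked_target)  # tensor([0, 1, 0, 3])
```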
# Get predicted-logits = logits[target].
# For simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
    predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(),
)
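The pair (arange_1d, masked_target_1d) is a per-row gather: row i contributes logits_2d[i, masked_target_1d[i]]. Rows whose target lives on another rank were clamped to index 0, so their meaningless value is zeroed before the all-reduce sums each rank's contribution. A minimal sketch of the indexing pattern, continuing the hypothetical two-rank split above:

```python
import torch

# Rank 1 holds vocab columns 4..7, i.e. a [rows=4, partition_vocab=4] local shard.
logits_2d = torch.arange(16, dtype=torch.float32).view(4, 4)
masked_target_1d = torch.tensor([0, 1, 0, 3])             # clamped local indices
target_mask = torch.tensor([True, False, True, False])    # rows owned by the other rank

arange_1d = torch.arange(logits_2d.size(0))
predicted_logits = logits_2d[arange_1d, masked_target_1d]  # tensor([ 0.,  5.,  8., 15.])
predicted_logits[target_mask] = 0.0                        # tensor([ 0.,  5.,  0., 15.])
# torch.distributed.all_reduce(predicted_logits, ...) would now add the other rank's rows.
```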
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(
    sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group(),
)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
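This is ordinary cross entropy written so that each term can be reduced independently: -log softmax(z)_t = log(sum_j exp(z_j)) - z_t, where both the sum of exponentials and the target logit were obtained via all-reduce. A single-device sanity check of the identity (no parallelism assumed):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)            # [tokens, vocab]
target = torch.randint(0, 10, (5,))

# log-sum-exp minus the target logit ...
lse_loss = torch.logsumexp(logits, dim=-1) - logits[torch.arange(5), target]
# ... equals the standard per-token cross entropy.
ce_loss = F.cross_entropy(logits, target, reduction="none")

print(torch.allclose(lse_loss, ce_loss))  # True
```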
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
if label_smoothing > 0:
    # Spread label_smoothing / (vocab_size - 1) probability mass over every non-target class.
    smoothing = label_smoothing * vocab_size / (vocab_size - 1)

    # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
    log_probs = torch.log(exp_logits)
    mean_log_probs = log_probs.mean(dim=-1)
    loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
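The rewritten form is algebraically identical to cross entropy against a smoothed target distribution that puts 1 - label_smoothing on the true class and label_smoothing / (K - 1) on each of the other K - 1 classes. A quick single-device check of that identity (plain torch, hypothetical sizes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, alpha = 10, 0.1
logits = torch.randn(4, K)
target = torch.randint(0, K, (4,))

log_probs = F.log_softmax(logits, dim=-1)
nll = -log_probs[torch.arange(4), target]

# Megatron-style rewrite.
smoothing = alpha * K / (K - 1)
loss_rewrite = (1.0 - smoothing) * nll - smoothing * log_probs.mean(dim=-1)

# Explicit soft-target cross entropy: 1 - alpha on the target, alpha / (K - 1) elsewhere.
soft_targets = torch.full((4, K), alpha / (K - 1))
soft_targets[torch.arange(4), target] = 1.0 - alpha
loss_explicit = -(soft_targets * log_probs).sum(dim=-1)

print(torch.allclose(loss_rewrite, loss_explicit))  # True
```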
# Save the smoothing factor and vocab size; the backward pass reads them from ctx.
ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size

# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
# All the inputs have softmax as their gradient.
grad_input = softmax

# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
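# (The step the comment above refers to is elided in this excerpt; the following is a
# sketch paraphrased from Megatron-LM rather than a verbatim quote.)
# Subtract the target distribution from the softmax, but only on rows whose target
# actually lives on this rank (target_mask == 1 means "owned by another rank").
softmax_update = 1.0 - target_mask.view(-1).float()

if label_smoothing > 0:
    smoothing = label_smoothing * vocab_size / (vocab_size - 1)
    # (1 - smoothing) is removed from the target column, smoothing / vocab_size from every column.
    grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
    grad_2d[arange_1d, :] -= smoothing / vocab_size
else:
    # Plain cross entropy: grad = softmax - one_hot(target).
    grad_2d[arange_1d, masked_target_1d] -= softmax_update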
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None, None
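The structure of the backward mirrors the familiar closed form: for plain cross entropy the gradient with respect to the logits is softmax(z) - one_hot(target), scaled by the incoming grad_output. A single-device autograd check of that identity (no parallelism assumed):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
target = torch.randint(0, 10, (4,))

loss = F.cross_entropy(logits, target, reduction="sum")
loss.backward()

# Closed-form gradient: softmax minus the one-hot target.
manual_grad = torch.softmax(logits, dim=-1) - F.one_hot(target, num_classes=10).float()
print(torch.allclose(logits.grad, manual_grad, atol=1e-6))  # True
```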
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Arguments:
        vocab_parallel_logits: logits split across tensor parallel ranks
                               dimension is [sequence_length, batch_size, vocab_size / num_tensor_parallel_ranks]
        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
                         default is no smoothing (=0.0)
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
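Spinning up a tensor-parallel group is out of scope for a snippet, but the expected shapes and the single-rank behaviour can be sanity-checked with plain torch: with one rank the result reduces to an ordinary un-reduced per-token cross entropy over [sequence_length, batch_size] positions (a reference sketch with hypothetical sizes, not a call into Megatron):

```python
import torch
import torch.nn.functional as F

seq_len, batch, vocab = 6, 2, 32                     # hypothetical sizes
logits = torch.randn(seq_len, batch, vocab)          # each rank would hold a vocab slice of this
target = torch.randint(0, vocab, (seq_len, batch))

# Single-rank reference: vocab_parallel_cross_entropy returns an un-reduced
# [seq_len, batch] loss, equivalent to this when the tensor parallel size is 1.
ref_loss = F.cross_entropy(
    logits.view(-1, vocab), target.view(-1), reduction="none"
).view(seq_len, batch)

print(ref_loss.shape)  # torch.Size([6, 2])
```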
# Start from the softmax result; positions that do not need the target-class
# correction are skipped via target_mask, whose elements are all 0 or 1.
grad_input = softmax

# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)