from typing import Dict

import torch


def get_random_microbatch(
    microbatch_size: int, max_sequence_length: int, vocab_size: int, seed: int
) -> Dict[str, Dict[str, torch.Tensor]]:
    """Generate a random microbatch of labels, loss mask, and token logits for testing.

    Note the shape convention: ``token_logits`` is [s, b, v] (sequence-first),
    while the other fields are [b, s] (batch-first). All tensors are allocated on
    the current CUDA device, and the seeded generator makes the output
    deterministic for a given set of arguments.
    """
    device = torch.cuda.current_device()
    # A dedicated, seeded generator keeps the microbatch reproducible per seed.
    generator = torch.Generator(device=device).manual_seed(seed)
    labels = torch.randint(
        low=0,
        high=vocab_size,
        size=(microbatch_size, max_sequence_length),
        generator=generator,
        device=device,
    )  # [b s]
    loss_mask = torch.randint(
        low=1,
        high=2,  # low=1, high=2 always draws 1, so every position contributes to the loss
        size=(microbatch_size, max_sequence_length),
        dtype=torch.long,
        device=device,
        generator=generator,
    )  # [b s], all ones
    token_logits = torch.rand(
        max_sequence_length, microbatch_size, vocab_size, device=device, generator=generator
    )  # [s b v]
    labels[loss_mask == 0] = -100  # propagate masking to labels (a no-op while loss_mask is all ones)
microbatch_output = {
"batch": {"labels": labels, "loss_mask": loss_mask},
"forward_out": {"token_logits": token_logits},
}
return microbatch_output
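

# A minimal usage sketch (illustrative, not part of the original module): it
# requires a CUDA device, since the helper allocates on torch.cuda.current_device().
# Identical seeds yield identical microbatches, which is what makes the helper
# useful for deterministic comparisons across test configurations.
if __name__ == "__main__":
    mb = get_random_microbatch(microbatch_size=2, max_sequence_length=8, vocab_size=32, seed=42)
    assert mb["batch"]["labels"].shape == (2, 8)  # [b s]
    assert mb["batch"]["loss_mask"].shape == (2, 8)  # [b s]
    assert mb["forward_out"]["token_logits"].shape == (8, 2, 32)  # [s b v]
    # Re-running with the same seed reproduces the same tensors exactly.
    mb_again = get_random_microbatch(microbatch_size=2, max_sequence_length=8, vocab_size=32, seed=42)
    assert torch.equal(mb["batch"]["labels"], mb_again["batch"]["labels"])
    assert torch.equal(mb["forward_out"]["token_logits"], mb_again["forward_out"]["token_logits"])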