Hi team - a question about the following line of code:
mask_prob = self.args.mask_prob + (.8 - self.args.mask_prob) * (epoch - 1) / 20
Can I ask what purpose this serves, and why? When resuming training from a checkpoint at epoch 28 or later, this evaluates to more than 1 (with mask_prob initialized to 0.2) and causes assertion failures. I note that if starting training from epoch 0 with no interruption, mask_prob remains fixed at the initialized value (0.2), but after any interruption training resumes with a higher mask_prob.
Should this just be mask_prob = self.args.mask_prob?
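To make the arithmetic concrete, here is a minimal standalone sketch that reproduces only the expression quoted above, outside of fairseq. The function names are hypothetical, and the base value of 0.2 is the initialized mask_prob mentioned in this question; nothing here reflects how the task actually calls this line.

```python
from typing import Optional


def scheduled_mask_prob(base_mask_prob: float, epoch: int) -> float:
    """Same arithmetic as the quoted line, with self.args.mask_prob passed in."""
    return base_mask_prob + (0.8 - base_mask_prob) * (epoch - 1) / 20


def first_epoch_above_one(base_mask_prob: float, max_epoch: int = 200) -> Optional[int]:
    """Return the first epoch at which the expression exceeds 1, or None."""
    for epoch in range(1, max_epoch + 1):
        if scheduled_mask_prob(base_mask_prob, epoch) > 1.0:
            return epoch
    return None


if __name__ == "__main__":
    # Print the value at a few epochs around the point where it crosses 1.
    for epoch in (1, 21, 27, 28):
        print(epoch, round(scheduled_mask_prob(0.2, epoch), 4))
    print("first epoch above 1:", first_epoch_above_one(0.2))
```

With a base of 0.2, the expression equals the base at epoch 1, reaches 0.8 at epoch 21, and keeps growing linearly after that, so it is no longer a valid probability once the epoch passed in is large enough.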
Code reference: trex/fairseq/tasks/trex.py, line 144 in b4b7b97.