-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Attention across documents. #213
base: main
Are you sure you want to change the base?
Conversation
@@ -742,6 +742,11 @@ def parse_args(args): | |||
action="store_true", | |||
help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.", | |||
) | |||
parser.add_argument( | |||
"--mask-across-documents", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think this should be an int
not a bool
so that a user can specify their EOT token
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense - will update the parameter.
if args.mask_across_documents: | ||
# Some input samples contain EOT as the final token. The prediction after that is meaningless, so it | ||
# should not contribute to the loss. | ||
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i prefer not to hard code our EOT to keep open_lm tokenizer agnostic
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed - I'll change it so that it uses the user defined EOT token.
* Update .gitignore * Fix requirements for env. * Remove test data prep file erroneously committed. * Revert requirements. * Update makefile. --------- Co-authored-by: George Smyrnis <gsmyrnis@utexas.edu>
4c322d1
to
7234b31
Compare
# Some input samples contain EOT as the final token. The prediction after that is meaningless, so it | ||
# should not contribute to the loss. | ||
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True) | ||
targets = targets.detach().clone() # Clone this because it shares mem with input! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, is the detach necessary here? When args.mask_across_documents is False, should we also a detach()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Detach is not necessary, but clone is - because the targets and the input share the underlying tensor, if the target is explicitly set then the input is also affected.
When args.mask_across_documents is False, this is not an issue - neither the target nor the input are explicitly changed.
This adds a flag that stops attention from going across documents, identified by the EOT token.
The loss for the token right after the EOT token is ignored.
TODO: add some tests for the shape of the mask.