-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathscripts.py
61 lines (51 loc) · 3.07 KB
/
scripts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn
import torch.nn.functional as F
import math
import torch
from pytorch_transformers import *
# PyTorch-Transformers has a unified API
# for 7 transformer architectures and 30 pretrained weights.
# Model | Tokenizer | Pretrained weights shortcut
MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
(OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
(GPT2Model, GPT2Tokenizer, 'gpt2'),
(TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
(XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
(XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),
(RobertaModel, RobertaTokenizer, 'roberta-base')]
# Let's encode some text in a sequence of hidden-states using each model:
for model_class, tokenizer_class, pretrained_weights in MODELS:
# Load pretrained model/tokenizer
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)
# Encode text
input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)]) # Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
with torch.no_grad():
last_hidden_states = model(input_ids)[0] # Models outputs are now tuples
# Each architecture is provided with several class for fine-tuning on down-stream tasks, e.g.
BERT_MODEL_CLASSES = [BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification,
BertForQuestionAnswering]
# All the classes for an architecture can be initiated from pretrained weights for this architecture
# Note that additional weights added for fine-tuning are only initialized
# and need to be trained on the down-stream task
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
for model_class in BERT_MODEL_CLASSES:
# Load pretrained model/tokenizer
model = model_class.from_pretrained('bert-base-uncased')
# Models can return full list of hidden-states & attentions weights at each layer
model = model_class.from_pretrained(pretrained_weights,
output_hidden_states=True,
output_attentions=True)
input_ids = torch.tensor([tokenizer.encode("Let's see all hidden-states and attentions on this text")])
all_hidden_states, all_attentions = model(input_ids)[-2:]
# Models are compatible with Torchscript
model = model_class.from_pretrained(pretrained_weights, torchscript=True)
traced_model = torch.jit.trace(model, (input_ids,))
# Simple serialization for models and tokenizers
model.save_pretrained('./directory/to/save/') # save
model = model_class.from_pretrained('./directory/to/save/') # re-load
tokenizer.save_pretrained('./directory/to/save/') # save
tokenizer = tokenizer_class.from_pretrained('./directory/to/save/') # re-load
# SOTA examples for GLUE, SQUAD, text generation...