john-hewitt / structural-probes

Codebase for testing whether hidden states of neural networks encode discrete structures.

segment type ids should be zeros instead of ones (minor update suggestion) #13

Open caspillaga opened 2 years ago

caspillaga commented 2 years ago

Just for the record, in case someone finds it useful or plans to extend it.

In line 48 of https://github.com/john-hewitt/structural-probes/blob/4c2e265d6bd071e6ab380fd9806e4c6a128b5e97/scripts/convert_raw_to_bert.py#L48, the segment type ids should be zeros, not ones as implemented there. BERT uses 0 for sentence A and 1 for sentence B, and each input here is a single sentence. I believe this will not make much difference anyway. Moreover, in the newer Hugging Face transformers API this parameter can be omitted, in which case the library fills it in automatically (with zeros), as in the code below.
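Assuming each input line is a single sentence, every position, including [CLS] and [SEP], belongs to segment A and should get id 0. The pure-Python sketch below spells out the BERT convention. The helper name `bert_segment_ids` is hypothetical; in transformers, the tokenizer method `create_token_type_ids_from_sequences` plays the same role.

```python
def bert_segment_ids(num_tokens_a, num_tokens_b=None):
    """Token type ids following the BERT convention (hypothetical helper).

    [CLS], sentence A and its trailing [SEP] get 0; sentence B and its
    trailing [SEP] get 1. Token counts exclude the special tokens.
    """
    ids = [0] * (num_tokens_a + 2)  # [CLS] A ... [SEP]
    if num_tokens_b is not None:
        ids += [1] * (num_tokens_b + 1)  # B ... [SEP]
    return ids


print(bert_segment_ids(4))     # single sentence: [0, 0, 0, 0, 0, 0]
print(bert_segment_ids(2, 3))  # sentence pair:   [0, 0, 0, 0, 1, 1, 1, 1]
```

For the single-sentence case used by this script, the result is all zeros, whereas filling the list with ones would mark every token as belonging to sentence B.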

In case someone finds it useful, I also updated the code to work with the newer transformers library.

The relevant changed lines are below (some lines omitted for clarity):

```python
import numpy as np
import torch
from transformers import BertTokenizerFast, BertModel

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizerFast.from_pretrained(...)
model = BertModel.from_pretrained(...)
LAYER_COUNT = 12 + 1  # embedding output + 12 layers; 24 + 1 for bert-large
FEATURE_COUNT = 768   # 1024 for bert-large
model.eval()

# Tokenize each PTB word separately so word boundaries are preserved,
# then flatten the per-word subword ids into a single sequence
word_pieces = tokenizer(line.split(), add_special_tokens=False,
                        return_token_type_ids=False, return_attention_mask=False)
indexed_tokens = [item for sublist in word_pieces['input_ids'] for item in sublist]
indexed_tokens = tokenizer.build_inputs_with_special_tokens(indexed_tokens)  # Add [CLS] and [SEP]

# Build a batch of one and run the model. token_type_ids is not passed,
# so BertModel defaults it to all zeros (sentence A).
tokens_tensor = torch.tensor([indexed_tokens])
with torch.no_grad():
    encoded_layers = model(input_ids=tokens_tensor, output_hidden_states=True)['hidden_states']

# line, index and fout come from the loop in the original script, omitted here for clarity
dset = fout.create_dataset(str(index), (LAYER_COUNT, len(indexed_tokens), FEATURE_COUNT))
# Each hidden state has shape (1, seq_len, FEATURE_COUNT); stacking along
# axis 0 gives (LAYER_COUNT, seq_len, FEATURE_COUNT)
dset[:, :, :] = np.vstack([x.numpy() for x in encoded_layers])
```
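Because each PTB word is tokenized on its own, one word can map to several subword positions. The sketch below shows, under stated assumptions, how the flattened ids could keep a record of which positions belong to which word, and how the tuple of per-layer hidden states becomes the `(LAYER_COUNT, seq_len, FEATURE_COUNT)` array written to HDF5. The helper names `flatten_with_spans` and `stack_layers` are hypothetical, the subword ids are made up, and `offset=1` assumes exactly one [CLS] token is prepended, as BERT does. The original repository may handle alignment differently.

```python
import numpy as np


def flatten_with_spans(word_piece_ids, offset=1):
    """Flatten per-word subword id lists, recording each word's [start, end) span.

    `offset` accounts for the [CLS] token prepended later by
    build_inputs_with_special_tokens (assumed to be one token, as for BERT).
    """
    flat, spans = [], []
    for ids in word_piece_ids:
        start = offset + len(flat)
        flat.extend(ids)
        spans.append((start, offset + len(flat)))
    return flat, spans


def stack_layers(hidden_states):
    """Turn a tuple of (1, seq_len, hidden) arrays into (num_layers, seq_len, hidden)."""
    return np.concatenate([np.asarray(h) for h in hidden_states], axis=0)


# Made-up ids for a 3-word sentence whose second word splits into 2 pieces
flat, spans = flatten_with_spans([[10], [20, 21], [30]])
print(flat)   # [10, 20, 21, 30]
print(spans)  # [(1, 2), (2, 4), (4, 5)]

# 13 fake hidden states for seq_len = 4 subwords + [CLS] + [SEP]
fake_layers = tuple(np.zeros((1, len(flat) + 2, 768)) for _ in range(13))
print(stack_layers(fake_layers).shape)  # (13, 6, 768)
```

`np.concatenate(..., axis=0)` is equivalent to the `np.vstack` call above for these 3-D inputs; the spans could then be used, for example, to average subword vectors back into one vector per PTB word.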