Closed vieting closed 2 years ago
In #110, I added a dummy run to test a serialized config. In the example given in test_negative_sampling, this does not yet work. The error can be reproduced by adding the dummy run to that test, i.e.:
test_negative_sampling
    def test_negative_sampling():
        n_batch, n_time, n_feat = 3, 14, 7  # B, T', F
        n_negatives = 10  # N

        def model_func(wrapped_import, inputs: torch.Tensor):
            if typing.TYPE_CHECKING or not wrapped_import:
                import torch
            else:
                torch = wrapped_import("torch")
            model = torch.nn.Conv1d(in_channels=n_feat, out_channels=n_feat, kernel_size=2, stride=3)
            inputs = model(inputs.transpose(1, 2)).transpose(1, 2).contiguous()
            bsz, tsz, fsz = inputs.shape  # (B,T,F)
            tszs = torch.arange(tsz).unsqueeze(-1).expand(-1, n_negatives).flatten()  # (T*N)
            neg_idxs = torch.randint(low=0, high=tsz - 1, size=(bsz, n_negatives * tsz))  # (B,T*N)
            neg_idxs = neg_idxs + (neg_idxs >= tszs).int()  # (B,T*N)
            neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * tsz)  # (B,T*N)
            y = inputs.view(-1, fsz)  # (B,T,F) => (B*T,F)
            negs = y[neg_idxs.view(-1)]  # (B*T*N,F)
            negs = negs.view(bsz, tsz, n_negatives, fsz).permute(2, 0, 1, 3)  # to (N,B,T,F)
            inputs_unsqueeze = inputs.unsqueeze(0)  # (1,B,T,F)
            targets = torch.cat([inputs_unsqueeze, negs], dim=0)  # (N+1,B,T,F)
            logits = torch.cosine_similarity(inputs.float(), targets.float(), dim=-1).type_as(inputs)
            return logits

        rnd = numpy.random.RandomState(42)
        x = rnd.normal(0., 1., (n_batch, n_time, n_feat)).astype("float32")
        converter = verify_torch_and_convert_to_returnn(model_func, inputs=x, inputs_data_kwargs={
            "shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2})

        cfg = converter.get_returnn_config_serialized()
        from returnn_helpers import config_net_dict_via_serialized, dummy_run_net
        config, net_dict = config_net_dict_via_serialized(cfg)
        dummy_run_net(config)
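
For readers unfamiliar with the `neg_idxs` lines in the test above: as far as I can tell, they draw negative indices per time step that never coincide with the step itself, by sampling from a range that is one shorter and shifting every index at or above the own position up by one. The following is only a minimal pure-Python sketch of that idea, not the code under test; the function name `sample_negative_indices` and its parameters are hypothetical and chosen just for illustration. It leaves out the batch offset (`torch.arange(bsz).unsqueeze(1) * tsz`) that the test adds for indexing into the flattened `(B*T, F)` tensor.

```python
import random


def sample_negative_indices(tsz, n_negatives, rng):
    """For each time step t, draw n_negatives indices from range(tsz) excluding t.

    Hypothetical helper for illustration; mirrors the randint + shift logic only.
    """
    result = []
    for t in range(tsz):
        row = []
        for _ in range(n_negatives):
            idx = rng.randrange(tsz - 1)  # uniform over 0 .. tsz-2
            if idx >= t:
                idx += 1  # skip over t itself, giving 0 .. tsz-1 without t
            row.append(idx)
        result.append(row)
    return result


# With n_time = 14, kernel_size = 2 and stride = 3, the Conv1d output
# length is (14 - 2) // 3 + 1 = 5, so tsz = 5 in the test.
rng = random.Random(42)
neg = sample_negative_indices(tsz=5, n_negatives=10, rng=rng)
assert all(idx != t and 0 <= idx < 5 for t, row in enumerate(neg) for idx in row)
```

The shift keeps the distribution uniform over the remaining `tsz - 1` positions without needing rejection sampling, which is presumably why the test is written this way.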
I'll add a draft PR and post the log and stack trace here once the tests finish.
See #112 and corresponding tests (here)
Traceback