Closed by dgcnz 1 month ago.
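Relevant changes:
- `model/wang2022/full.yaml` and an example config at `experiments/wang2022/equivariance_test/full.yaml`
- `autoregressive_train` flag added to the wang2022 module:

```python
import torch
from lightning import LightningModule  # assumed import path; the PR excerpt omits imports


class Wang2022LightningModule(LightningModule):
    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
        autoregressive_train: bool = True,  # new flag from this PR; enabled by default
    ) -> None:
        ...  # constructor body not shown in the PR description
```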
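The PR does not show how `autoregressive_train` is consumed in the training step. The following is a minimal, framework-free sketch, assuming the flag chooses between one-step (teacher-forced) training and an autoregressive rollout in which the model's own predictions are fed back as inputs. The names `one_step_loss`, `rollout_loss`, and `training_loss` are hypothetical, and plain floats stand in for tensors.

```python
from typing import Callable, Sequence

# Hypothetical sketch: not the module's actual training_step.
Model = Callable[[float], float]


def one_step_loss(model: Model, states: Sequence[float]) -> float:
    """Teacher forcing: every prediction starts from the ground-truth state."""
    errors = [(model(states[t]) - states[t + 1]) ** 2 for t in range(len(states) - 1)]
    return sum(errors) / len(errors)


def rollout_loss(model: Model, states: Sequence[float]) -> float:
    """Autoregressive: only the first state is ground truth; later inputs are predictions."""
    x = states[0]
    errors = []
    for target in states[1:]:
        x = model(x)  # feed the prediction back in as the next input
        errors.append((x - target) ** 2)
    return sum(errors) / len(errors)


def training_loss(model: Model, states: Sequence[float], autoregressive_train: bool = True) -> float:
    """Assumed dispatch on the flag, mirroring its default of True."""
    if autoregressive_train:
        return rollout_loss(model, states)
    return one_step_loss(model, states)


if __name__ == "__main__":
    true_states = [1.0, 2.0, 4.0, 8.0]   # toy dynamics: x_{t+1} = 2 * x_t
    model = lambda x: 2.1 * x            # slightly wrong model
    print(training_loss(model, true_states, autoregressive_train=False))  # ~0.07
    print(training_loss(model, true_states))                              # ~0.589407
```

With the slightly wrong toy model, the rollout loss (about 0.589) is far larger than the one-step loss (0.07), because errors compound when predictions are fed back. Penalizing that compounding is the usual motivation for autoregressive training.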