spfrommer / torchexplorer

Interactively inspect module inputs, outputs, parameters, and gradients.
https://spfrommer.github.io/torchexplorer/
Apache License 2.0

Visualizing GPT2 architecture shows only input and output #54

Open datafireball opened 3 months ago

datafireball commented 3 months ago

I am following Andrej Karpathy's tutorial on building GPT-2 from scratch, and I thought it would be a good idea to visualize his GPT-2 model using torchexplorer.

This is what I did:

  1. installed torchexplorer (Windows 10, using pygraphviz from the alubbock channel)
  2. placed torchexplorer.watch(model, log=['io', 'params'], backend='standalone') before calling the model
  3. ran one step (step 0) with a full forward and backward pass
  4. checked localhost:8080 (see the toy-model sanity check after this list)
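
For reference, here is a minimal sketch of the same setup on a toy model (a hypothetical stand-in, not the GPT-2 code), which should render all submodules if the installation is working:

import torch
import torch.nn as nn
import torchexplorer

# hypothetical toy model, just to verify the standalone backend setup
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
torchexplorer.watch(model, log=['io', 'params'], backend='standalone')

# one forward/backward pass so 'io' and 'params' get logged, then open localhost:8080
x = torch.randn(4, 8)
model(x).sum().backward()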

However, it seems to only capture the input and output, not the rest of the network. Any thoughts?

[screenshot: the rendered graph shows only input and output nodes]

I started from the first commit in his project to avoid more complex features like DDP, context managers, etc. Here is the code:

import tiktoken
import torch
import torchexplorer

# GPT and GPTConfig are the model classes from the tutorial (definitions omitted here)

class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T

        # at init load tokens from disk and store them in memory
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

# -----------------------------------------------------------------------------
# attempt to autodetect the device
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

train_loader = DataLoaderLite(B=4, T=32)

# create the model
model = GPT(GPTConfig())
model.to(device)

# optimize!
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

torchexplorer.watch(model, log=['io', 'params'], backend='standalone')

for i in range(50):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    print(f"step {i}, loss: {loss.item()}")
    break  # run only a single step for this repro
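
For what it's worth, a quick shape check (hypothetical, not part of the tutorial code; it assumes input.txt is present) confirms the loader's next-token input/target pairing:

loader = DataLoaderLite(B=4, T=32)
xb, yb = loader.next_batch()
print(xb.shape, yb.shape)  # torch.Size([4, 32]) torch.Size([4, 32])
# within a row, the targets are the inputs shifted by one token
assert torch.equal(xb[0, 1:], yb[0, :-1])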
spfrommer commented 2 months ago

I can confirm that I've been able to reproduce this -- unfortunately, I don't have the bandwidth this summer to look into what's happening with this architecture. Thank you for the bug report in any case.