srush / Tensor-Puzzles

Solve puzzles. Improve your pytorch.
MIT License

Help with puzzles 4 and 5 #9

Open SnehaMondal opened 1 year ago

SnehaMondal commented 1 year ago

I'm trying to solve puzzles 4 and 5 using the predefined `where(q, a, b)` function, which expects `q` to be a boolean tensor. To get a boolean tensor, I build a boolean list and call `tensor` on it. I suppose this isn't allowed? How else could I implement this?

This is my current implementation for puzzle 5 (the identity matrix of dimension j):

```python
def eye(j: int) -> TT["j", "j"]:
    return where(
        tensor([[x == y for x in range(j)] for y in range(j)]),
        ones(j) - ones(j)[:, None] + ones(j),
        ones(j) - ones(j)[:, None],
    )
```

srush commented 1 year ago

Everything you're currently doing with list comprehensions can be done faster and more simply with the `arange` function.

For example, `arange(10) == 2` gives you a tensor that is all False except for a True at position 2.
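A minimal sketch of this hint, assuming `arange` behaves like the puzzle helper (here a thin wrapper over `torch.arange`; the repo's own definition may differ slightly). Comparing against a scalar gives a one-hot mask, and comparing a row against a column broadcasts into a full 2-D mask, with no list comprehension needed:

```python
import torch

def arange(i: int) -> torch.Tensor:
    # Assumed stand-in for the puzzle's arange helper.
    return torch.arange(i)

# Scalar comparison: a single True at index 2, False everywhere else.
one_hot = arange(10) == 2
print(one_hot)

# Broadcasting a (j,) row against a (j, 1) column yields a (j, j) mask
# that is True exactly where the row index equals the column index.
j = 4
mask = arange(j) == arange(j)[:, None]
print(mask)
```

The second mask is exactly the boolean `q` that the list comprehension above was building by hand.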

atgctg commented 1 year ago

After some experimentation, I came up with the following:

```python
def eye(j: int) -> TT["j", "j"]:
    return where(arange(j) - arange(j)[:, None] == 0, 1, 0)
```
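To see why this works, it helps to print the broadcasted difference: entry `[r, c]` holds `c - r`, so the zeros fall exactly on the main diagonal. The sketch below assumes stand-ins for the puzzle helpers: `arange` wraps `torch.arange`, and `where` is written as `(q * a) + (~q) * b`, a definition assumed (not confirmed here) to match the repo's helper. It is split over two lines for readability, whereas the puzzles ask for one.

```python
import torch

def arange(i: int) -> torch.Tensor:
    # Assumed stand-in for the puzzle's arange helper.
    return torch.arange(i)

def where(q: torch.Tensor, a, b) -> torch.Tensor:
    # Assumed stand-in for the puzzle's where helper: take a where q is
    # True and b elsewhere, using only arithmetic on the boolean mask.
    return (q * a) + (~q) * b

def eye(j: int) -> torch.Tensor:
    # offsets[r, c] = c - r, so the main diagonal is exactly the zeros.
    offsets = arange(j) - arange(j)[:, None]
    return where(offsets == 0, 1, 0)

print(arange(3) - arange(3)[:, None])
print(eye(3))
```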

I wonder if there's an even better way...

srush commented 1 year ago

That's great! There are a couple of ways to do it, but this is the way I did it.

The other approach I've seen is to multiply by 1 to force the type change, something like:

```python
(arange(n) == arange(n)[:, None]) * 1
```
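The multiplication works because of PyTorch's type promotion: a bool tensor multiplied by a Python int is promoted to the default integer dtype, so True/False become 1/0. A small sketch using plain `torch` calls (not the puzzle helpers) to show the dtype change:

```python
import torch

n = 3
mask = torch.arange(n) == torch.arange(n)[:, None]
print(mask.dtype)  # torch.bool

# bool * int promotes to int64, turning True/False into 1/0.
ident = mask * 1
print(ident.dtype)  # torch.int64
print(ident)
```

Outside the puzzle rules, `mask.long()` would do the same conversion explicitly; the `* 1` trick is handy because it needs only arithmetic.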

CreativeSelf0 commented 1 year ago

I really loved the way you thought about it; it's out-of-the-box thinking :P. It would be cool if you put your solutions out there. Although that might incentivize people to give up easily, a person who gave up after trying would still learn. This is how I solved it:

```python
def eye(j: int) -> TT["j", "j"]:
    a = tensor([0] * j)
    index = arange(j)
    data = outer(a, a)  # j x j matrix of zeros
    data[index, index] = 1  # set the diagonal entries in place
    return data
```
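The line `data[index, index] = 1` relies on PyTorch's advanced indexing, where equal-length index tensors are paired elementwise. The sketch below uses plain `torch` calls (`torch.zeros` in place of `outer(a, a)`) to show that pairing, and contrasts it with broadcast index tensors, which select every row/column combination instead:

```python
import torch

j = 4
index = torch.arange(j)

# Equal-length index tensors are paired elementwise, so this writes 1
# into (0, 0), (1, 1), ..., (j-1, j-1), modifying data in place.
data = torch.zeros(j, j, dtype=torch.long)
data[index, index] = 1
print(data)

# A (j, 1) index against a (j,) index broadcasts to (j, j), so every
# (row, col) pair is selected and the whole matrix is filled.
grid = torch.zeros(j, j, dtype=torch.long)
grid[index[:, None], index] = 1
print(grid)  # all ones
```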

jselvam11 commented 10 months ago

```python
where(arange(j)[:, None] == arange(j), outer(ones(j), ones(j)), 0)
```
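Since `where` is built from broadcasting arithmetic, the `outer(ones(j), ones(j))` argument gives the same result as a plain scalar `1`. The sketch below checks this under stated assumptions: `arange` and `where` are assumed stand-ins for the puzzle helpers, and `ones` and `outer` are hypothetical puzzle-style implementations, not the repo's reference solutions.

```python
import torch

def arange(i: int) -> torch.Tensor:
    # Assumed stand-in for the puzzle's arange helper.
    return torch.arange(i)

def where(q: torch.Tensor, a, b) -> torch.Tensor:
    # Assumed stand-in for the puzzle's where helper.
    return (q * a) + (~q) * b

def ones(i: int) -> torch.Tensor:
    # Hypothetical puzzle-style ones: every index is >= 0.
    return where(arange(i) >= 0, 1, 0)

def outer(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical puzzle-style outer product via broadcasting.
    return a[:, None] * b

j = 3
mask = arange(j)[:, None] == arange(j)
with_outer = where(mask, outer(ones(j), ones(j)), 0)
with_scalar = where(mask, 1, 0)  # the scalar 1 broadcasts to (j, j)
print(torch.equal(with_outer, with_scalar))  # True
```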

CharlesDDDD commented 2 weeks ago

```python
return ((arange(j) + arange(j)[:, None]) == diag(arange(j) + arange(j)[:, None])) * 1
```
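This version works because `sums[r, c] = r + c`, so its diagonal is `d[c] = 2c`, and the broadcast comparison `sums[r, c] == d[c]` holds exactly when `r + c == 2c`, i.e. `r == c`. The sketch below checks that reasoning. It also shows one possible approach to puzzle 4 (`diag`), pairing row `k` with column `k` via indexing; this `diag` is a hypothetical implementation, and `arange` is an assumed stand-in for the puzzle helper.

```python
import torch

def arange(i: int) -> torch.Tensor:
    # Assumed stand-in for the puzzle's arange helper.
    return torch.arange(i)

def diag(a: torch.Tensor) -> torch.Tensor:
    # Hypothetical puzzle-4 answer: pick a[k, k] for each k.
    return a[arange(a.shape[0]), arange(a.shape[0])]

j = 4
sums = arange(j) + arange(j)[:, None]  # sums[r, c] = r + c
d = diag(sums)                         # d[c] = 2c
# Broadcasting compares sums[r, c] with d[c], and r + c == 2c iff r == c.
ident = (sums == d) * 1
print(d)
print(ident)
```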

Only after looking at you guys' solutions did I realize how inelegant I am 🥹