PyTorch can use the GPU (through the MPS backend on Apple silicon Macs), as the following code shows.
import torch
import math
def _select_device():
    """Return the MPS device when it is usable, otherwise the CPU device.

    Hard-coding ``torch.device("mps")`` fails on machines without MPS
    support; falling back to the CPU keeps the script runnable everywhere.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _fit_cubic(x, y, coeffs, steps, learning_rate, report_every=100):
    """Fit ``y ~ a + b*x + c*x**2 + d*x**3`` with plain gradient descent.

    Gradients of the sum-of-squared-errors loss are derived by hand; no
    autograd is used.

    Args:
        x: Input points (tensor).
        y: Target values, same shape as ``x``.
        coeffs: Initial ``(a, b, c, d)`` coefficients as scalar tensors.
        steps: Number of gradient-descent iterations.
        learning_rate: Gradient-descent step size.
        report_every: Print ``"<step> <loss>"`` on every ``report_every``-th
            step (must be a positive integer).

    Returns:
        Tuple of the updated ``(a, b, c, d)`` tensors.
    """
    a, b, c, d = coeffs
    # x**2 and x**3 do not change between iterations; compute them once.
    x2 = x ** 2
    x3 = x ** 3
    for t in range(steps):
        # Forward pass: compute predicted y.
        y_pred = a + b * x + c * x2 + d * x3
        residual = y_pred - y
        # Only turn the loss into a Python float when it is printed:
        # .item() copies to the host and forces a device synchronization.
        if t % report_every == report_every - 1:
            print(t, residual.pow(2).sum().item())
        # Backprop: d(loss)/d(y_pred) = 2 * (y_pred - y), then chain rule
        # through each polynomial term.
        grad_y_pred = 2.0 * residual
        grad_a = grad_y_pred.sum()
        grad_b = (grad_y_pred * x).sum()
        grad_c = (grad_y_pred * x2).sum()
        grad_d = (grad_y_pred * x3).sum()
        # Update weights using gradient descent.
        a = a - learning_rate * grad_a
        b = b - learning_rate * grad_b
        c = c - learning_rate * grad_c
        d = d - learning_rate * grad_d
    return a, b, c, d


# True when the MPS backend can be used on this machine (per the PyTorch
# docs this requires macOS 12.3+ and MPS-capable hardware).
print(torch.backends.mps.is_available())
# True when this PyTorch installation was built with MPS support.
print(torch.backends.mps.is_built())
dtype = torch.float
device = _select_device()
# Create input points and targets: one period of sin(x).
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Randomly initialize the four polynomial coefficients (scalar tensors).
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)
learning_rate = 1e-6
a, b, c, d = _fit_cubic(x, y, (a, b, c, d), 2000, learning_rate)
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
As the code above shows, PyTorch can use the GPU.
The output is (the exact numbers vary between runs because the weights are initialized randomly): True True 99 1057.953369140625 199 732.12646484375 299 508.0294189453125 399 353.72833251953125 499 247.36788940429688 599 173.97360229492188 699 123.27410125732422 799 88.21514892578125 899 63.94678497314453 999 47.13111114501953 1099 35.46820831298828 1199 27.371381759643555 1299 21.745071411132812 1399 17.832063674926758 1499 15.10826301574707 1599 13.210683822631836 1699 11.887638092041016 1799 10.964457511901855 1899 10.319819450378418 1999 9.869365692138672 Result: y = 0.03163151070475578 + 0.8690047264099121 x + -0.005456964019685984 x^2 + -0.09507481753826141 x^3
Process finished with exit code 0