PyTorch can use the GPU (via the MPS backend), as the following code shows:
import torch
import math

# True if MPS is usable here (requires macOS 12.3+ and an MPS-enabled device)
print(torch.backends.mps.is_available())
# True if the current PyTorch installation was built with MPS support
print(torch.backends.mps.is_built())

dtype = torch.float
device = torch.device("mps")

# Create input data (2000 points on [-pi, pi]) and target output y = sin(x)
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
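The backward pass in the loop hand-codes the derivatives of the summed squared error L = Σ(y_pred − y)²: ∂L/∂a = Σ 2(y_pred − y), ∂L/∂b = Σ 2(y_pred − y)·x, and likewise with x² and x³ for c and d. One way to sanity-check those formulas without a GPU (or PyTorch at all) is a central finite-difference comparison. The sketch below is a plain-Python illustration on a small 50-point grid; the helper names (`loss_fn`, `analytic_grads`, `numeric_grads`) and the test point are hypothetical, chosen only for this check.

```python
import math


def loss_fn(params, xs, ys):
    """Sum of squared errors of the cubic a + b*x + c*x^2 + d*x^3 against ys."""
    a, b, c, d = params
    return sum((a + b * x + c * x**2 + d * x**3 - y) ** 2 for x, y in zip(xs, ys))


def analytic_grads(params, xs, ys):
    """Gradients using the same formulas as the hand-written backward pass."""
    a, b, c, d = params
    ga = gb = gc = gd = 0.0
    for x, y in zip(xs, ys):
        g = 2.0 * (a + b * x + c * x**2 + d * x**3 - y)  # plays the role of grad_y_pred
        ga += g
        gb += g * x
        gc += g * x**2
        gd += g * x**3
    return [ga, gb, gc, gd]


def numeric_grads(params, xs, ys, eps=1e-5):
    """Central finite differences, perturbing one parameter at a time."""
    grads = []
    for i in range(len(params)):
        up = list(params)
        up[i] += eps
        down = list(params)
        down[i] -= eps
        grads.append((loss_fn(up, xs, ys) - loss_fn(down, xs, ys)) / (2 * eps))
    return grads


# Small float64 grid on [-pi, pi], analogous to torch.linspace(-math.pi, math.pi, n)
n = 50
xs = [-math.pi + i * (2 * math.pi / (n - 1)) for i in range(n)]
ys = [math.sin(x) for x in xs]
params = [0.1, -0.2, 0.3, -0.05]  # arbitrary (hypothetical) test point

analytic = analytic_grads(params, xs, ys)
numeric = numeric_grads(params, xs, ys)
for name, ga, gn in zip("abcd", analytic, numeric):
    print(f"grad_{name}: analytic={ga:.6f} numeric={gn:.6f}")
```

Because the loss is quadratic in (a, b, c, d), central differences are exact up to floating-point rounding here, so the two columns should agree to many digits.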
The output is:
True
True
99 1057.953369140625
199 732.12646484375
299 508.0294189453125
399 353.72833251953125
499 247.36788940429688
599 173.97360229492188
699 123.27410125732422
799 88.21514892578125
899 63.94678497314453
999 47.13111114501953
1099 35.46820831298828
1199 27.371381759643555
1299 21.745071411132812
1399 17.832063674926758
1499 15.10826301574707
1599 13.210683822631836
1699 11.887638092041016
1799 10.964457511901855
1899 10.319819450378418
1999 9.869365692138672
Result: y = 0.03163151070475578 + 0.8690047264099121 x + -0.005456964019685984 x^2 + -0.09507481753826141 x^3
Process finished with exit code 0
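To get a feel for how good the final fit is, one can plug the printed coefficients back in and compare against sin(x). The plain-Python sketch below recomputes the summed squared error on the same 2000-point grid, assuming float64 arithmetic instead of the script's float32 (which should only shift the result slightly). It should land close to the last logged loss of 9.869, since the final coefficients differ from that step by just one more small gradient update. It also reports the worst pointwise error, which for a cubic fit like this tends to sit near the ends of the interval.

```python
import math

# Coefficients copied from the "Result:" line of the output above
a = 0.03163151070475578
b = 0.8690047264099121
c = -0.005456964019685984
d = -0.09507481753826141

# Same grid as torch.linspace(-math.pi, math.pi, 2000), but computed in float64
n = 2000
step = 2 * math.pi / (n - 1)
xs = [-math.pi + i * step for i in range(n)]

# Residuals of the fitted cubic against the target sin(x)
residuals = [a + b * x + c * x**2 + d * x**3 - math.sin(x) for x in xs]
loss = sum(r * r for r in residuals)          # same summed squared error as the script
max_abs_err = max(abs(r) for r in residuals)  # worst-case pointwise error

print(f"loss = {loss:.4f}")
print(f"max |error| = {max_abs_err:.4f}")
```

The loss is still falling at step 1999 in the log above, so more iterations would likely tighten the fit further.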