pytorch基础--torch

本文最后更新于 2025年1月26日 晚上

PyTorch 的 torch 模块

torch 是 PyTorch 的核心模块,提供了张量操作、数学运算、设备管理等功能。以下是其主要类和功能的详细介绍。


1. 张量操作(Tensor Operations)

torch 提供了丰富的张量操作函数,用于创建、操作和转换张量。

常用函数

函数名用途描述
torch.tensor()创建张量。
torch.zeros()创建全零张量。
torch.ones()创建全一张量。
torch.rand()创建均匀分布随机张量。
torch.randn()创建标准正态分布随机张量。
torch.arange()创建等差序列张量。
torch.linspace()创建等间隔序列张量。
torch.eye()创建单位矩阵。
torch.cat()沿指定维度拼接张量。
torch.stack()沿新维度堆叠张量。
torch.split()将张量分割为多个子张量。
torch.reshape()改变张量形状。
torch.transpose()转置张量。
torch.matmul()矩阵乘法。
torch.sum()计算张量元素和。
torch.mean()计算张量元素均值。
torch.max()计算张量元素最大值。
torch.min()计算张量元素最小值。

2. 数学运算(Mathematical Operations)

torch 提供了丰富的数学运算函数。

常用函数

函数名用途描述
torch.add()张量加法。
torch.sub()张量减法。
torch.mul()张量乘法(逐元素)。
torch.div()张量除法(逐元素)。
torch.pow()张量幂运算。
torch.sqrt()张量平方根。
torch.exp()张量指数运算。
torch.log()张量对数运算。
torch.sin()张量正弦函数。
torch.cos()张量余弦函数。
torch.tanh()张量双曲正切函数。
torch.abs()张量绝对值。
torch.clamp()将张量元素限制在指定范围内。

3. 设备管理(Device Management)

torch 提供了设备管理功能,支持在 CPU 和 GPU 之间切换。

常用函数

函数名用途描述
torch.cuda.is_available()检查 GPU 是否可用。
torch.device()指定设备(如 'cuda''cpu')。
torch.to()将张量或模型移动到指定设备。
torch.cuda.empty_cache()清空 GPU 缓存。

4. 自动微分(Autograd)

torchautograd 模块支持自动微分,用于计算梯度。

常用函数

函数名用途描述
torch.tensor(requires_grad=True)创建需要计算梯度的张量。
torch.backward()计算梯度。
torch.grad()计算指定变量的梯度。
torch.no_grad()禁用梯度计算(用于推理或冻结参数)。
torch.detach()返回一个不需要梯度的新张量。

5. 随机数生成(Random Number Generation)

torch 提供了随机数生成功能。

常用函数

函数名用途描述
torch.manual_seed()设置随机种子。
torch.rand()生成均匀分布随机数。
torch.randn()生成标准正态分布随机数。
torch.randint()生成整数随机数。
torch.randperm()生成随机排列。

6. 文件操作(File Operations)

torch 提供了模型和张量的保存与加载功能。

常用函数

函数名用途描述
torch.save()保存模型或张量。
torch.load()加载模型或张量。
torch.load_state_dict()加载模型参数。
torch.save_state_dict()保存模型参数。

7. 其他功能

torch 还提供了许多其他功能,如分布式训练、FFT、稀疏张量等。

常用模块

模块名用途描述
torch.distributed分布式训练支持。
torch.fft快速傅里叶变换。
torch.sparse稀疏张量支持。
torch.jit模型脚本化和优化。
torch.onnx将模型导出为 ONNX 格式。

示例代码

1. 创建张量

1
2
3
4
5
6
import torch

# 创建张量
x = torch.tensor([1, 2, 3])
y = torch.zeros(2, 3) # 2x3 全零张量
z = torch.rand(3, 3) # 3x3 随机张量

2. 数学运算

1
2
3
4
5
6
7
8
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

# 加法
c = torch.add(a, b)

# 矩阵乘法
d = torch.matmul(a.unsqueeze(0), b.unsqueeze(1))

3. 自动微分

1
2
3
4
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 1
y.backward()
print(x.grad) # 输出梯度值

4. 设备管理

1
2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.tensor([1, 2, 3]).to(device)

5. 保存与加载模型

1
2
3
4
5
# 保存模型
torch.save(model.state_dict(), "model.pth")

# 加载模型
model.load_state_dict(torch.load("model.pth"))

pytorch基础--torch
https://jimes.cn/2025/01/26/pytorch的torch/
作者
Jimes
发布于
2025年1月26日
许可协议