PyTorch 的 torch
模块
torch
是 PyTorch 的核心模块,提供了张量操作、数学运算、设备管理等功能。以下是其主要类和功能的详细介绍。
1. 张量操作(Tensor Operations)
torch
提供了丰富的张量操作函数,用于创建、操作和转换张量。
常用函数
函数名 | 用途描述 |
---|
torch.tensor() | 创建张量。 |
torch.zeros() | 创建全零张量。 |
torch.ones() | 创建全一张量。 |
torch.rand() | 创建均匀分布随机张量。 |
torch.randn() | 创建标准正态分布随机张量。 |
torch.arange() | 创建等差序列张量。 |
torch.linspace() | 创建等间隔序列张量。 |
torch.eye() | 创建单位矩阵。 |
torch.cat() | 沿指定维度拼接张量。 |
torch.stack() | 沿新维度堆叠张量。 |
torch.split() | 将张量分割为多个子张量。 |
torch.reshape() | 改变张量形状。 |
torch.transpose() | 转置张量。 |
torch.matmul() | 矩阵乘法。 |
torch.sum() | 计算张量元素和。 |
torch.mean() | 计算张量元素均值。 |
torch.max() | 计算张量元素最大值。 |
torch.min() | 计算张量元素最小值。 |
2. 数学运算(Mathematical Operations)
torch
提供了丰富的数学运算函数。
常用函数
函数名 | 用途描述 |
---|
torch.add() | 张量加法。 |
torch.sub() | 张量减法。 |
torch.mul() | 张量乘法(逐元素)。 |
torch.div() | 张量除法(逐元素)。 |
torch.pow() | 张量幂运算。 |
torch.sqrt() | 张量平方根。 |
torch.exp() | 张量指数运算。 |
torch.log() | 张量对数运算。 |
torch.sin() | 张量正弦函数。 |
torch.cos() | 张量余弦函数。 |
torch.tanh() | 张量双曲正切函数。 |
torch.abs() | 张量绝对值。 |
torch.clamp() | 将张量元素限制在指定范围内。 |
3. 设备管理(Device Management)
torch
提供了设备管理功能,支持在 CPU 和 GPU 之间切换。
常用函数
函数名 | 用途描述 |
---|
torch.cuda.is_available() | 检查 GPU 是否可用。 |
torch.device() | 指定设备(如 'cuda' 或 'cpu' )。 |
torch.to() | 将张量或模型移动到指定设备。 |
torch.cuda.empty_cache() | 清空 GPU 缓存。 |
4. 自动微分(Autograd)
torch
的 autograd
模块支持自动微分,用于计算梯度。
常用函数
函数名 | 用途描述 |
---|
torch.tensor(requires_grad=True) | 创建需要计算梯度的张量。 |
torch.backward() | 计算梯度。 |
torch.grad() | 计算指定变量的梯度。 |
torch.no_grad() | 禁用梯度计算(用于推理或冻结参数)。 |
torch.detach() | 返回一个不需要梯度的新张量。 |
5. 随机数生成(Random Number Generation)
torch
提供了随机数生成功能。
常用函数
函数名 | 用途描述 |
---|
torch.manual_seed() | 设置随机种子。 |
torch.rand() | 生成均匀分布随机数。 |
torch.randn() | 生成标准正态分布随机数。 |
torch.randint() | 生成整数随机数。 |
torch.randperm() | 生成随机排列。 |
6. 文件操作(File Operations)
torch
提供了模型和张量的保存与加载功能。
常用函数
函数名 | 用途描述 |
---|
torch.save() | 保存模型或张量。 |
torch.load() | 加载模型或张量。 |
torch.load_state_dict() | 加载模型参数。 |
torch.save_state_dict() | 保存模型参数。 |
7. 其他功能
torch
还提供了许多其他功能,如分布式训练、FFT、稀疏张量等。
常用模块
模块名 | 用途描述 |
---|
torch.distributed | 分布式训练支持。 |
torch.fft | 快速傅里叶变换。 |
torch.sparse | 稀疏张量支持。 |
torch.jit | 模型脚本化和优化。 |
torch.onnx | 将模型导出为 ONNX 格式。 |
示例代码
1. 创建张量
1 2 3 4 5 6
| import torch
x = torch.tensor([1, 2, 3]) y = torch.zeros(2, 3) z = torch.rand(3, 3)
|
2. 数学运算
1 2 3 4 5 6 7 8
| a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6])
c = torch.add(a, b)
d = torch.matmul(a.unsqueeze(0), b.unsqueeze(1))
|
3. 自动微分
1 2 3 4
| x = torch.tensor(2.0, requires_grad=True) y = x**2 + 3*x + 1 y.backward() print(x.grad)
|
4. 设备管理
1 2
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") x = torch.tensor([1, 2, 3]).to(device)
|
5. 保存与加载模型
1 2 3 4 5
| torch.save(model.state_dict(), "model.pth")
model.load_state_dict(torch.load("model.pth"))
|