import torch
# 原始评分矩阵 5x4
A = torch.tensor([[5, 3, 0, 1],
[4, 0, 0, 1],
[1, 1, 0, 5],
[1, 0, 0, 4],
[0, 1, 5, 4]], dtype=torch.float32)
# 1.对矩阵A进行SVD分解
U, sigma, Vt = torch.svd(A)
# 打印结果
print("U:", U.shape) # torch.Size([5, 4]) 为什么不是 5x5 ?
print("S:", sigma.shape) # S: torch.Size([4])
print("V:", Vt.shape) # V: torch.Size([4, 4])
# 2.构造奇异值矩阵
Sigma = torch.diag(sigma)
# 3.重构原始矩阵
reconstructed_A = torch.mm(torch.mm(U, Sigma), Vt.t())
print("Original matrix A:")
print(A)
print("\nReconstructed matrix A:")
print(reconstructed_A)
# torch.svd_lowrank ()是PyTorch中的一个函数,用于计算矩阵的低秩奇异值分解(Low-rank Singular Value Decomposition,简称LSVD)。
import torch
# 创建一个随机矩阵 (复用 A)
# 计算低秩 SVD
U, S, V = torch.svd_lowrank(A, q=4) # 不配置q 会报错 ,why ?
# 打印结果
print("U:", U.shape) # U: torch.Size([5, 4])
print("S:", S.shape) # S: torch.Size([4])
print("V:", V.shape) # V: torch.Size([4, 4])
# 重构矩阵
A_reconstructed = U @ torch.diag(S) @ V.T
# 打印重构矩阵
print("Original A:\n", A)
print("Reconstructed A:\n", A_reconstructed)
"""
torch.svd:
计算复杂度: 计算完整 SVD,时间复杂度为 (O(mn^2)),适用于较小的矩阵。
适用场景: 适用于需要精确 SVD 的场景,特别是在矩阵规模较小时。
torch.svd_lowrank:
计算复杂度: 计算低秩 SVD,时间复杂度较低,适用于大规模矩阵。
适用场景: 适用于大规模矩阵的近似 SVD,特别是在需要减少计算复杂度和内存使用时。
"""
0条评论