Pytorch迁移学习训练自己的图像分类模型【两天搞定AI毕设】_哔哩哔哩_bilibili

B1 环境安装

基本库

用到的一些基本库

pip install numpy scipy scikit-learn pandas matplotlib jupyter seaborn plotly requests tqdm opencv-python wandb pillow 

PyTorch + CUDA

要安装对应的版本,需要到官网进行安装选择。注意貌似是要 python ≥ 3.8 的版本。

PyTorch

Untitled

用这段代码检测 CUDA。


B2 torchvision.transforms 图像预处理

# ImageNet 上数据集 RGB 的均值 0.485, 0.456, 0.406;方差 0.229, 0.224, 0.225
IMAGENET_IMG_MU = [0.485, 0.456, 0.406]
IMAGENET_IMG_SIGMA = [0.229, 0.224, 0.225]

transform_train = torchvision.transforms.Compose([
    # 随机长宽比裁剪子图,得到分辨率为 224 * 224
    torchvision.transforms.RandomResizedCrop(224),
    # 有 0.5 的概率进行水平翻转
    torchvision.transforms.RandomHorizontalFlip(),
    # 转换为 tensor,并且归一化至 [0, 1](方法是直接除以 255)
    torchvision.transforms.ToTensor(),
    # 数据标准化
    torchvision.transforms.Normalize(IMAGENET_IMG_MU, IMAGENET_IMG_SIGMA),
])

transform_test = torchvision.transforms.Compose([
    # 将图像分辨率转小后,采集中心 0.875 的部分
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    # 转换为 tensor,并且归一化至 [0, 1](方法是直接除以 255)
    torchvision.transforms.ToTensor(),
    # 数据标准化
    torchvision.transforms.Normalize(IMAGENET_IMG_MU, IMAGENET_IMG_SIGMA),
])

一些常见的 torchvision.transforms 变换:

torchvision的使用(transforms用法介绍)_torchvision.transforms_迷雾总会解的博客-CSDN博客

Illustration of transforms — Torchvision 0.15 documentation