0%

代码学习

一、前言

本篇博文用于记录学习他人论文复现代码中一些库的使用以及深度学习框架pytorch的一些语法。

二、有用的库

1.setuptools打包工具

是python标准的打包分发工具,可以将编写的python项目打包安装,使其它人可以像调用标准库或python第三方库一样直接使用。

setup.py定义打包程序的一些信息

1
2
3
4
5
6
7
8
9
10
11
setup(
name="dconv", #应用名
version="0.1", #版本号
author="fmassa", #作者
url="https://github.com/facebookresearch/maskrcnn-benchmark", #程序的官网地址
description="object detection in pytorch",
packages=find_packages(exclude=("configs", "tests",)), #需要处理的包目录
# install_requires=requirements,
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)

参数介绍:python的构建工具setup.py - 人生苦短,python当歌 - 博客园 (cnblogs.com)

执行 python setup.py XXX

生成目录格式如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
.
├── DemoApp.egg-info Egg相关信息
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   └── top_level.txt
├── build build后文件
│   ├── bdist.macosx-10.14-intel
│   └── lib
│   └── demoapp
│   └── __init__.py
├── demoapp 源文件
│   └── __init__.py
├── dist
│   └── DemoApp-0.0-py2.7.egg 应用名-版本号-py版本.egg
└── setup.py

python setup.py build操作:如果软件包含C扩展名或定义了一些自定义的编译任务,它们也将被编译。如果只含python文件,复制全部是build。

2.argparse.ArgumentParser()用法

argparse是一个python模块:命令行选项、参数和子命令解析器。帮助定义程序使用参数、解析参数、自动生成帮助和使用手册与报错。

使用:创建argparse.ArgumentParser类–>>添加参数–>>解析参数

创建argparse.ArgumentParser类:

1
2
3
4
5
6
7
#example:
parser = argparse.ArgumentParser(description='Train image model with cross entropy loss')
#usage:
ArgumentParser对象
prog - 程序的名称(默认: sys.argv[0],prog猜测是programma的缩写)
usage - 描述程序用途的字符串(默认值:从添加到解析器的参数生成)
description - 在参数帮助文档之后显示的文本 (默认值:无)

add_argument()添加参数:

1
2
3
4
5
6
7
8
9
10
11
12
#example:
parser.add_argument('-d', '--dataset', type=str, default='miniImageNet_load')
parser.add_argument('--root', type=str, default='/miniImageNet_pickle')
#usage:
add_argument()方法:
name or flags - 一个命名或者一个选项字符串的列表
action - 表示该选项要执行的操作
default - 当参数未在命令行中出现时使用的值
dest - 用来指定参数的位置
type - 为参数类型,例如int
choices - 用来选择输入参数的范围。例如choice = [1, 5, 10], 表示输入参数只能为1,5 或10
help - 用来描述这个选项的作用

更具体的参数:python add_argument()用法解析 - 灰信网(软件开发博客聚合) (freesion.com)

解析参数parser.parse_args():

1
2
args = parser.parse_args()
使用参数:args.参数名

使用小tips:

1
2
parser = argument_parser()	#在argument_parser中获取参数
args = parser.parse_args()

3.os模块的常见使用

链接:Python OS 文件/目录方法 | 菜鸟教程 (runoob.com)

常用:

1
2
os.path.join('path1', 'path2'....,'file_name')	合并路径与文件
os.mkdir(path, mode=) 类似于mkdir

三、pytorch语法

1.随机初始化种子

1
2
3
torch.manual_seed(SEED)				#CPU设置随机种子
torch.cuda.manual_seed(SEED) #当前GPU的随机种子
torch.cuda.manual_seed_all(SEED) #所有GPU的随机种子

神经网络的参数初始化是随机的,如果想保证论文成果可复现,应使用初始化种子保证每次随机初始化参数的一致性。

2.Dataset和DataLoader

pytorch中加载数据的顺序:创建一个dataset对象–>>创建一个dataloader对象–>>迭代dataloader,获取训练/测试数据

Dataset类以torch.utils.data.dataset作为基类,需要完成__int__、 __len__、__getitem__三个函数的重载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.utils.data.dataset as Dataset
import numpy as np

#创建子类
class subDataset(Dataset.Dataset):
#初始化,定义数据内容和标签
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label
#返回数据集大小
def __len__(self):
return len(self.Data)
#得到数据内容和标签
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.Tensor(self.Label[index])
return data, label

Dataloader参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.

Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: 1).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: False).
sampler (Sampler, optional): defines the strategy to draw samples from
the dataset. If specified, ``shuffle`` must be False.
batch_sampler (Sampler, optional): like sampler, but returns a batch of
indices at a time. Mutually exclusive with batch_size, shuffle,
sampler, and drop_last.
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means that the data will be loaded in the main process.
(default: 0)
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
into CUDA pinned memory before returning them.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: False)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional): If not None, this will be called on each
worker subprocess with the worker id as input, after seeding and before data
loading. (default: None)
"""

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):

def __iter__(self):
return DataLoaderIter(self)

def __len__(self):
return len(self.batch_sampler)

较为复杂,现在不是很熟。不太写得出来,放个别人的链接

(40条消息) PyTorch源码解读之torch.utils.data.DataLoader_AI之路-CSDN博客_torch.utils.data