-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
99 lines (78 loc) · 3.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
from datasets import DATASETS
from models import MODELS
from transforms import *
from utils import Timer, Counter, save_checkpoint, load_checkpoint, calculate_eta
def main(args):
    """Run the full training loop for an image-segmentation model.

    Builds the dataset/loader, constructs the model and SGD optimizer,
    optionally resumes from a checkpoint, trains for ``args.epochs`` epochs
    while logging loss/mIoU to TensorBoard, and saves a checkpoint after
    every epoch.

    Args:
        args: Parsed CLI namespace (see ``options/train.py``). Fields read
            here include ``device``, ``logs_dir``, ``dataset``, ``train_root``,
            ``val_root``, ``batch_size``, ``num_workers``, ``model_name``,
            ``lr``, ``epochs``, ``resume`` and ``save_dir``.

    Raises:
        ValueError: If ``args.dataset`` or ``args.model_name`` is not a
            registered key.
    """
    device = torch.device(args.device)
    # One timestamped sub-directory per run so runs don't overwrite each other.
    logger = SummaryWriter(os.path.join(args.logs_dir, "tensorboard_log{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now())))

    # Validate registry keys up front with a specific exception type.
    if args.dataset not in DATASETS:
        raise ValueError(f'`--dataset` is invalid. it should be one of {list(DATASETS.keys())}')
    if args.model_name not in MODELS:
        raise ValueError(f'`--model_name` is invalid. it should be one of {list(MODELS.keys())}')

    train_data = DATASETS[args.dataset](args.train_root,
                                        transforms=Compose([RandomCrop(224),
                                                            RandomHorizontalFlip(),
                                                            RandomVerticalFlip(),
                                                            ToTensor(),
                                                            Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
                                        **args.__dict__)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers)

    # NOTE(review): the validation loader is constructed but never consumed —
    # no evaluation pass runs during training. Left as-is to preserve
    # behavior; wire up an eval loop (net.eval() + torch.no_grad()) if
    # validation is actually intended. TODO confirm with the author.
    do_val = False
    if args.val_root:
        val_data = DATASETS[args.dataset](args.val_root, transforms=ToTensor(), **args.__dict__)
        val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers)
        do_val = True

    # Load ImageNet-pretrained weights only for a fresh run; a resumed run
    # gets its weights from the checkpoint below.
    net = MODELS[args.model_name](pretrained=args.resume is None, **args.__dict__).to(device)
    optimizer = torch.optim.SGD(params=net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)

    start_epoch, total_epoch, global_step = 0, args.epochs, 0
    if args.resume is not None:
        checkpoint = load_checkpoint(args.resume)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        global_step = checkpoint['global_step']
        print(f"=> Start epoch {start_epoch} ")

    for epoch in range(start_epoch, total_epoch):
        net.train()
        timer, counter = Timer(), Counter()
        timer.start()
        for step, (img, label) in enumerate(train_loader):
            img, label = img.to(device), label.to(device)
            # Time spent fetching + transferring this batch.
            reader_time = timer.elapsed_time()

            # The model's forward computes both the loss and the mIoU metric.
            loss, miou = net(img, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss = float(loss)  # detach to a plain float for logging
            batch_time = timer.elapsed_time()
            counter.append(loss=loss, miou=miou, reader_time=reader_time, batch_time=batch_time)
            eta = calculate_eta(len(train_loader) - step, counter.batch_time)
            # "current/running-average" pairs; \r keeps it on one console line.
            print(f"[epoch={epoch + 1}/{total_epoch}] "
                  f"step={step + 1}/{len(train_loader)} "
                  f"loss={loss:.4f}/{counter.loss:.4f} "
                  f"miou={miou:.4f}/{counter.miou:.4f} "
                  f"batch_time={counter.batch_time:.4f} "
                  f"reader_time={counter.reader_time:.4f} "
                  f"| ETA {eta}",
                  end="\r",
                  flush=True)

            logger.add_scalar("loss", float(loss), global_step=global_step)
            logger.add_scalar("miou", float(miou), global_step=global_step)
            global_step += 1
            logger.flush()
            timer.restart()
        print()  # terminate the \r progress line before the next epoch

        # Persist after every epoch so interrupted runs can --resume.
        save_checkpoint({
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'global_step': global_step,
        }, epoch + 1, False, save_dir=args.save_dir)

    # Release the TensorBoard writer's file handles.
    logger.close()
# Script entry point: parse command-line arguments (defined in
# options/train.py) and launch training. The import is deferred so that
# importing this module for its functions has no CLI side effects.
if __name__ == '__main__':
    from options.train import parse_args
    main(parse_args())