Training and Image Sorting Improvements (#4)
*  You can now continue training a model on a new dataset regardless of the number of training classes. Previously, you had to use the same number of classes to continue training.
* Automatic image sorting now works on every image inside the target directory/folder. Previously, images in subdirectories were ignored.
ProGamerGov committed Sep 8, 2020
1 parent 17c42ee commit fdba70d
Showing 2 changed files with 46 additions and 35 deletions.
66 changes: 34 additions & 32 deletions data_tools/sort_images.py
@@ -70,40 +70,42 @@ def main_func(params):
     print('Sorting images into the following classes:')
     print(' ',class_list)
     ext = [".jpg", ".jpeg", ".png", ".tiff", ".bmp"]
-    for file in os.listdir(params.unsorted_data):
-        if os.path.splitext(file)[1].lower() in ext:
-            test_input = preprocess(os.path.join(params.unsorted_data, file), params.data_mean, params.data_sd).to(params.use_device)
-            output = cnn(test_input)
-            output = output[0] if type(output) == tuple else output
-            index = output.argmax().item()
-
-            if params.confidence_min != -1 or params.confidence_max != -1:
-                sm = torch.nn.Softmax(dim=1)
-                probabilities = sm(output)[0][index]
-                if params.confidence_min != -1 and params.confidence_max != -1:
-                    confident = True if params.confidence_min < probabilities and probabilities < params.confidence_max else False
-
-                elif params.confidence_min != -1:
-                    confident = True if params.confidence_min < probabilities else False
-                elif params.confidence_max != -1:
-                    confident = True if params.confidence_max > probabilities else False
-            else:
-                confident = True
-
-            if index == params.cat and confident or params.cat == -1 and confident:
-                if params.cat != -1 and index == params.cat:
-                    new_path = os.path.join(params.sorted_data, str(params.cat))
-                else:
-                    if params.class_strings == '':
-                        new_path = os.path.join(params.sorted_data, str(index))
-                    else:
-                        new_path = os.path.join(params.sorted_data, class_strings[index])
-                print(index, file)
-
-                try:
-                    shutil.copy2(os.path.join(os.path.normpath(params.unsorted_data), file), os.path.join(new_path, file))
-                except (OSError, SyntaxError) as oe:
-                    print('Failed:', os.path.join(os.path.normpath(params.unsorted_data), file))
+    for current_path, dirs, files in os.walk(params.unsorted_data, topdown=True):
+        for file in files:
+            if os.path.splitext(file)[1].lower() in ext:
+                test_input = preprocess(os.path.join(current_path, file), params.data_mean, params.data_sd).to(params.use_device)
+                output = cnn(test_input)
+                output = output[0] if type(output) == tuple else output
+                index = output.argmax().item()
+
+                if params.confidence_min != -1 or params.confidence_max != -1:
+                    sm = torch.nn.Softmax(dim=1)
+                    probabilities = sm(output)[0][index]
+                    if params.confidence_min != -1 and params.confidence_max != -1:
+                        confident = True if params.confidence_min < probabilities and probabilities < params.confidence_max else False
+
+                    elif params.confidence_min != -1:
+                        confident = True if params.confidence_min < probabilities else False
+                    elif params.confidence_max != -1:
+                        confident = True if params.confidence_max > probabilities else False
+                else:
+                    confident = True
+
+                if index == params.cat and confident or params.cat == -1 and confident:
+                    if params.cat != -1 and index == params.cat:
+                        new_path = os.path.join(params.sorted_data, str(params.cat))
+                    else:
+                        if params.class_strings == '':
+                            new_path = os.path.join(params.sorted_data, str(index))
+                        else:
+                            new_path = os.path.join(params.sorted_data, class_strings[index])
+                    print(index, file)
+
+                    try:
+                        shutil.copy2(os.path.join(os.path.normpath(current_path), file), os.path.join(new_path, file))
+                    except (OSError, SyntaxError) as oe:
+                        print('Failed:', os.path.join(os.path.normpath(current_path), file))



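The substance of this change is replacing os.listdir, which only sees the top level of the unsorted folder, with os.walk, which visits every subdirectory. A minimal standalone sketch of that traversal and extension filter, assuming a hypothetical helper name collect_images and an example directory name (neither is part of the repository):

import os

def collect_images(unsorted_root):
    # Yield every image path under unsorted_root, including files nested in
    # subdirectories. os.listdir() would only return the top-level entries.
    ext = [".jpg", ".jpeg", ".png", ".tiff", ".bmp"]
    for current_path, dirs, files in os.walk(unsorted_root, topdown=True):
        for file in files:
            if os.path.splitext(file)[1].lower() in ext:
                yield os.path.join(current_path, file)

if __name__ == "__main__":
    for image_path in collect_images("unsorted_images"):
        print(image_path)

Each yielded path keeps its subdirectory prefix, which is why the copy step in the diff builds the source path from current_path rather than from the root directory.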
15 changes: 12 additions & 3 deletions utils/training_utils.py
@@ -180,14 +180,23 @@ def setup_model(model_file='pt_bvlc.pth', num_classes=120, base_model='bvlc', pr

 # Load checkpoint
 def load_checkpoint(cnn, model_file, optimizer, lrscheduler, num_classes, device='cuda:0', is_start_model=True):
-    start_epoch = 1
+    start_epoch, change_fc = 1, False

     checkpoint = torch.load(model_file, map_location='cpu')
     if type(checkpoint) == dict:
         model_keys = list(checkpoint.keys())

-        if not is_start_model:
-            cnn.replace_fc(num_classes, True)
+        try:
+            load_classes = checkpoint['num_classes']
+            if num_classes != load_classes:
+                is_start_model, change_fc = True, True # Pretend it's a starting model so FC gets replaced
+            else:
+                load_classes = num_classes
+        except:
+            load_classes = num_classes
+
+        if not is_start_model or change_fc:
+            cnn.replace_fc(load_classes, True)

         cnn.load_state_dict(checkpoint['model_state_dict'])

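The loading change reads an optional 'num_classes' entry from the checkpoint dictionary and, when it differs from the requested class count, rebuilds the final fully connected layer so training can continue on a dataset with a different number of classes. A minimal sketch of the same idea, assuming a model that exposes a plain .fc head; load_for_new_dataset is a hypothetical helper, and the repository itself uses its own replace_fc method:

import torch
import torch.nn as nn

def load_for_new_dataset(model, checkpoint_path, new_num_classes):
    # Load a checkpoint that records how many classes it was trained with.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    saved_classes = checkpoint.get('num_classes', new_num_classes)

    # Build the head with the saved class count so the state dict loads cleanly.
    model.fc = nn.Linear(model.fc.in_features, saved_classes)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Swap in a fresh head sized for the new dataset only when the counts differ,
    # keeping every other pretrained weight intact.
    if saved_classes != new_num_classes:
        model.fc = nn.Linear(model.fc.in_features, new_num_classes)
    return model

Saving 'num_classes' alongside 'model_state_dict' is what makes this possible; without it, the loader falls back to assuming the checkpoint matches the requested class count.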
