forked from elijahcole/caltech-ee148-spring2020-hw01
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_predictions.py
83 lines (60 loc) · 2.24 KB
/
run_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import numpy as np
import json
from PIL import Image
def detect_red_light(I):
    '''
    Return a list <bounding_boxes> with one entry per red light in image <I>.

    Each element of <bounding_boxes> is itself a list of four integers that
    specify a bounding box: the row and column index of the top left corner
    and the row and column index of the bottom right corner (in that order).

    The body below is still the template's placeholder detector: it does not
    look at the pixel values and simply returns between 1 and 5 randomly
    placed boxes of fixed size, to demonstrate the required output format.

    Parameters
    ----------
    I : numpy array
        Image array whose first two axes are rows and columns. Any trailing
        channel axis is accepted but not required, so 2-D (grayscale) arrays
        work as well.

        Note that PIL loads images in RGB order, so:
        I[:,:,0] is the red channel
        I[:,:,1] is the green channel
        I[:,:,2] is the blue channel

    Returns
    -------
    bounding_boxes : list of lists
        Each inner list is [tl_row, tl_col, br_row, br_col], made of plain
        Python ints. The list is empty when the image is too small to
        contain a single box.
    '''

    bounding_boxes = []  # This should be a list of lists, each of length 4.

    # ---- BEGIN YOUR CODE ----

    # As an example, here's code that generates between 1 and 5 random boxes
    # of fixed size and returns the results in the proper format.
    box_height = 8
    box_width = 6

    # Only the row/column extents are needed; slicing the shape (instead of
    # unpacking exactly three values) keeps grayscale images from crashing.
    n_rows, n_cols = np.shape(I)[:2]

    # Exclusive upper bounds for the top-left corner. The bottom-right corner
    # is tl + box size and must remain a valid index, hence no "+ 1" here.
    tl_row_limit = n_rows - box_height
    tl_col_limit = n_cols - box_width

    # If the image cannot hold even one box, return no detections rather than
    # letting randint fail with an unrelated-looking "high <= 0" error.
    if tl_row_limit > 0 and tl_col_limit > 0:
        # randint's upper bound is exclusive, so 6 is needed to allow 5 boxes.
        num_boxes = np.random.randint(1, 6)

        for _ in range(num_boxes):
            # int() guarantees plain Python ints, which json can serialize.
            tl_row = int(np.random.randint(tl_row_limit))
            tl_col = int(np.random.randint(tl_col_limit))
            br_row = tl_row + box_height
            br_col = tl_col + box_width

            bounding_boxes.append([tl_row, tl_col, br_row, br_col])

    # ---- END YOUR CODE ----

    # Sanity check on the output format before handing it back to the caller.
    for box in bounding_boxes:
        assert len(box) == 4

    return bounding_boxes
# Driver script: run detect_red_light on every JPEG in data_path and write
# the predictions, keyed by file name, to preds.json inside preds_path.

# set the path to the downloaded data:
# NOTE(review): this path is relative to './' while preds_path below is
# relative to '../' -- confirm that both locations are intended.
data_path = './data/RedLights2011_Medium'

# set a path for saving predictions:
preds_path = '../data/hw01_preds'
os.makedirs(preds_path, exist_ok=True)  # create directory if needed

# get sorted list of files, removing any non-JPEG files. The check is made on
# the actual extension (case-insensitively) rather than with a substring
# test, so names such as 'x.jpg.bak' are skipped and 'X.JPG' is kept.
file_names = sorted(
    f for f in os.listdir(data_path)
    if f.lower().endswith(('.jpg', '.jpeg'))
)

# maps each image file name to its list of predicted bounding boxes
preds = {}
for file_name in file_names:

    # read image using PIL; the context manager closes the underlying file
    # handle instead of leaving one open per image until garbage collection
    with Image.open(os.path.join(data_path, file_name)) as img:
        # convert to numpy array while the file is still open
        I = np.asarray(img)

    preds[file_name] = detect_red_light(I)

# save preds (overwrites any previous predictions!)
with open(os.path.join(preds_path, 'preds.json'), 'w') as f:
    json.dump(preds, f)