-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathutils.py
32 lines (25 loc) · 1.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
import numpy as np
import pickle
import json
import os
#from collections import defaultdict
#from scipy import ndimage
def flatten_tf_array(array):
    """Reshape a tensor to 2-D, keeping axis 0 and collapsing all other axes.

    Args:
        array: Object exposing ``get_shape().as_list()`` that ``tf.reshape``
            accepts (presumably a ``tf.Tensor`` of rank >= 2 -- confirm
            against callers).

    Returns:
        The result of ``tf.reshape`` with shape
        ``[shape[0], product of shape[1:]]``.

    Raises:
        TypeError: If any trailing dimension is statically unknown (``None``),
            because the flattened size cannot be computed.
    """
    shape = array.get_shape().as_list()
    # An unknown leading (batch) dimension is reported as None, which
    # tf.reshape rejects; -1 asks tf.reshape to infer it instead.
    leading_dim = -1 if shape[0] is None else shape[0]
    # Multiply every trailing dimension rather than exactly three of them, so
    # inputs that are not 4-D are flattened correctly as well.
    flat_dim = 1
    for dim in shape[1:]:
        flat_dim *= dim
    return tf.reshape(array, [leading_dim, flat_dim])
def accuracy(predictions, labels):
    """Return the percentage of rows whose arg-max positions agree.

    Args:
        predictions: 2-D array with one row per sample; must expose ``.shape``.
        labels: 2-D array-like with one row per sample (presumably one-hot
            targets -- confirm against callers).

    Returns:
        ``100.0 * matches / predictions.shape[0]``, where ``matches`` is the
        number of rows in which both inputs have their maximum at the same
        column index.
    """
    predicted_classes = np.argmax(predictions, axis=1)
    expected_classes = np.argmax(labels, axis=1)
    match_count = np.sum(predicted_classes == expected_classes)
    sample_count = predictions.shape[0]
    return 100.0 * match_count / sample_count
def randomize(dataset, labels):
    """Shuffle a dataset and its labels with one shared random permutation.

    Args:
        dataset: NumPy array whose first axis indexes samples; any rank >= 1.
        labels: NumPy array whose first axis indexes the same samples. Its
            length determines the size of the permutation.

    Returns:
        Tuple ``(shuffled_dataset, shuffled_labels)``: both inputs reordered
        along axis 0 by the same permutation, so sample/label pairs stay
        aligned.
    """
    permutation = np.random.permutation(labels.shape[0])
    # Index only the leading axis. The previous ``[permutation, :, :]`` form
    # raised IndexError for datasets with fewer than three dimensions; this
    # form yields the same result for rank >= 3 and also handles 1-D and 2-D.
    shuffled_dataset = dataset[permutation]
    shuffled_labels = labels[permutation]
    return shuffled_dataset, shuffled_labels
def one_hot_encode(np_array, num_labels):
    """Build a float32 one-hot matrix from an array of class ids.

    Args:
        np_array: NumPy array of class ids; it is given a new axis at
            position 1 and compared against ``0 .. num_labels - 1``.
        num_labels: Number of columns in the encoding.

    Returns:
        float32 array holding 1.0 where the class id equals the column index
        and 0.0 elsewhere. Ids outside ``0 .. num_labels - 1`` yield an
        all-zero row.
    """
    class_ids = np.arange(num_labels)
    # Adding an axis lets each id broadcast against the row of class ids.
    id_column = np_array[:, np.newaxis]
    hot_mask = class_ids == id_column
    return hot_mask.astype(np.float32)
def reformat_data(dataset, labels, image_width, image_height, image_depth, num_labels=None):
    """Reshape images, one-hot encode their labels, and shuffle both together.

    Args:
        dataset: Iterable of per-image data; each item is converted to an
            array and reshaped to ``(image_width, image_height, image_depth)``.
        labels: Sequence of numeric class ids, one per image (presumably
            non-negative integers starting at 0 -- confirm against callers).
        image_width: First dimension of each reshaped image.
        image_height: Second dimension of each reshaped image.
        image_depth: Third dimension of each reshaped image.
        num_labels: Optional number of one-hot columns. When omitted it is
            inferred from ``labels``; pass it explicitly to keep the encoding
            width the same across data splits.

    Returns:
        Tuple ``(np_dataset, np_labels)``: the stacked image array and the
        float32 one-hot label matrix, shuffled by ``randomize`` with a shared
        permutation.
    """
    np_dataset_ = np.array([
        np.array(image_data).reshape(image_width, image_height, image_depth)
        for image_data in dataset
    ])
    label_ids = np.array(labels, dtype=np.float32)
    if num_labels is None:
        num_labels = len(np.unique(labels))
        if label_ids.size:
            # Counting distinct ids is too narrow when the ids are not
            # contiguous from 0 (e.g. a subset that lacks one class): the
            # highest ids would then encode as all-zero rows. Widen so every
            # id up to the maximum gets its own column.
            num_labels = max(num_labels, int(label_ids.max()) + 1)
    np_labels_ = one_hot_encode(label_ids, num_labels)
    np_dataset, np_labels = randomize(np_dataset_, np_labels_)
    return np_dataset, np_labels