# Notebook magics: inline plots and auto-reload of edited modules.
%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
# Let cuDNN benchmark and pick the fastest conv algorithms (input size is fixed at 224).
torch.backends.cudnn.benchmark=True
import fastText as ft
import torchvision.transforms as transforms

# ImageNet channel statistics — required because we fine-tune an ImageNet-pretrained ResNet.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Torchvision training-time augmentation: 224px random crop + horizontal flip.
# NOTE(review): `tfms` is reassigned later via fastai's `tfms_from_model`, so this
# torchvision pipeline appears unused downstream — confirm before relying on it.
tfms = transforms.Compose([
    # RandomResizedCrop replaces the deprecated RandomSizedCrop alias
    # (same behavior; the old name was removed in later torchvision releases).
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# Quick sanity check: open and display one Food-101 image.
fname = 'images/waffles/971843.jpg'
PATH = Path('/data/food-101')
TMP_PATH = PATH/'tmp'
TMP_PATH.mkdir(exist_ok=True)
img = Image.open(PATH/fname)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f2cc810c080>
import fastai
# Backbone for the image-to-word-vector regressor.
arch=resnet34
# fastai (train, val) transform pair: 224px, side-on augmentation, slight zoom.
ttfms,vtfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)
def to_array(x, y):
    """Convert an image to a float32 array scaled to [0, 1]; the label is dropped."""
    scaled = np.array(x).astype(np.float32) / 255
    return scaled, None
def TT(x, y):
    """Wrap a numpy array as a torch tensor (shares memory); the label is dropped."""
    tensor = torch.from_numpy(x)
    return tensor, None
# Prepend the PIL→float32 conversion to the training transform chain.
# (The numpy→tensor step TT is deliberately left commented out here.)
ttfms.tfms = [to_array] + ttfms.tfms# + [TT]
ttfms(img)
array([[[ 0.96785, 0.96421, 0.96502, ..., 2.24891, 2.24891, 2.24891], [ 0.9785 , 0.9944 , 0.98907, ..., 2.24317, 2.23908, 2.24891], [ 0.97824, 0.97529, 0.97648, ..., 2.24891, 2.24869, 2.24891], ..., [-0.75144, -0.86714, -0.96067, ..., -1.30119, -1.29639, -1.31187], [-0.65392, -0.78303, -0.90652, ..., -1.27335, -1.27576, -1.26882], [-0.95539, -1.00646, -1.05198, ..., -1.27215, -1.27553, -1.26904]], [[ 1.71638, 1.71478, 1.71518, ..., 1.9099 , 1.89143, 1.85349], [ 1.70217, 1.72155, 1.71532, ..., 1.99923, 1.90453, 1.85487], [ 1.67606, 1.67971, 1.67864, ..., 1.96195, 1.89924, 1.85116], ..., [-1.18751, -1.28125, -1.34907, ..., -1.44528, -1.43596, -1.46097], [-1.12796, -1.21854, -1.29414, ..., -1.44168, -1.44 , -1.44183], [-1.31508, -1.38547, -1.43945, ..., -1.46204, -1.4649 , -1.45775]], [[ 2.42496, 2.42334, 2.42356, ..., 1.16553, 1.11424, 1.0231 ], [ 2.42815, 2.45321, 2.4449 , ..., 1.28741, 1.18724, 1.10634], [ 2.43509, 2.43959, 2.43773, ..., 1.23707, 1.16099, 1.07339], ..., [-1.05983, -1.19064, -1.2806 , ..., -1.36088, -1.35557, -1.37079], [-0.96328, -1.10416, -1.21674, ..., -1.35046, -1.3494 , -1.35049], [-1.26429, -1.3345 , -1.38626, ..., -1.35884, -1.35839, -1.36259]]], dtype=float32)
# Load the pretrained English-Wikipedia fastText model (binary format, ~2.5M words).
ft_vecs = ft.load_model(str((PATH/'wiki.en.bin')))
# the English Wikipedia (from which these fasttext vectors are based) contains many imported words,
# such as this South Korean delicacy.
ft_vecs.get_word_vector('bibimbap')
array([-0.17312, 0.67789, 0.02758, -0.19625, -0.12882, 0.02245, -0.18145, -0.78455, 0.03599, -0.4044 , 0.03361, 0.1978 , -0.10437, -0.50488, 0.32577, -0.21824, -0.48772, 0.24602, 0.30623, -0.15783, -0.30381, -0.3505 , 0.19436, 0.35325, -0.11009, 0.27115, -0.54712, -0.38872, 0.20073, -0.24999, -0.05064, -0.33633, -0.45624, 0.19107, -0.04186, 0.34411, 0.11012, 0.10668, -0.01386, -0.32704, -0.08347, 0.41185, 0.64833, -0.14697, 0.03427, -0.02415, 0.12514, -0.16299, -0.40224, 0.60419, -0.09834, 0.02416, 0.22191, 0.27945, -0.06426, 0.49234, -0.1661 , 0.07503, 0.32363, -0.22936, 0.48197, 0.31985, 0.10119, -0.53915, -0.05682, -0.09341, 0.11307, 0.09115, 0.515 , 0.13764, -0.22909, -0.17249, 0.59863, 0.09867, -0.13623, 0.27125, 0.27687, 0.57246, -0.45064, 0.2504 , 0.06641, 0.09382, 0.15634, -0.38996, -0.06068, -0.2496 , -0.18251, 0.30635, 0.62988, -0.0858 , 0.44684, -0.38676, 0.15667, -0.56013, 0.57022, -0.29776, -0.05273, 0.21835, 0.59727, -0.43789, 0.12256, -0.30611, -0.37125, 0.07134, 0.18511, 0.54203, -0.19928, 0.50199, -0.07429, 0.01258, -0.03046, -0.26513, 0.09646, -0.14024, -0.33365, -0.26094, 0.26558, -0.19574, -0.01899, 0.47574, 0.20377, -0.02935, 0.01797, 0.39284, 0.18347, 0.09497, -0.24858, -0.4553 , 0.09672, 0.99625, 0.35264, -0.02869, 0.08237, 0.16365, -0.22485, -0.11307, -0.25275, -0.3289 , -0.4505 , -0.30572, 0.29358, 0.16704, 0.12084, 0.16699, 0.27638, 0.58782, -0.06847, 0.00668, -0.3894 , 0.10499, 0.37805, -0.00994, 0.22307, 0.46086, 0.08499, -0.05885, 0.2862 , -0.20062, 0.36248, 0.12401, 0.43504, 0.085 , -0.23846, -0.1597 , 0.10513, 0.12748, -0.72129, -0.49533, 0.6134 , 0.02878, -0.45184, -0.30724, -0.3441 , 0.07802, -0.0218 , 0.23205, 0.07547, -0.38557, -0.01962, 0.18933, -0.18683, 0.13515, 0.2464 , -0.05562, -0.37576, -0.12637, -0.07848, -0.31799, 0.61322, 0.3874 , 0.08806, -0.31721, 0.28414, 0.58549, -0.75648, 0.02487, -0.06267, 0.30738, -0.62018, 0.16412, -0.06167, 0.028 , 0.58349, -0.14395, 0.45287, -0.30723, 0.05226, 0.02623, 0.33187, -0.03541, 
1.03115, 0.02051, 0.22106, -0.10503, 0.15577, -0.14716, -0.10544, 0.11361, -0.192 , 0.13184, 0.24489, 0.21402, -0.07268, -0.16517, -0.44544, -0.15136, -0.36431, 0.1427 , -0.06281, -0.00214, 0.45211, -0.09094, -0.46689, 0.06809, -0.20073, -0.4712 , -0.04248, -0.26233, -0.19474, -0.10058, 0.28576, 0.46258, -0.26795, 0.35431, -0.33528, -0.02054, 0.32385, 0.89658, -0.12202, -0.73352, 0.01898, 0.06232, 0.0415 , 0.05949, 0.07843, 0.32239, 0.05867, -0.36756, 0.26816, 0.1976 , 0.00044, 0.21464, 0.02376, 0.05112, -0.39343, -0.32372, -0.01607, 0.17569, -0.14019, -0.15025, -0.28291, 0.00314, 0.41763, -0.09531, -0.02298, -0.10097, 0.01493, 0.07405, -0.0517 , 0.06373, 0.57262, 0.6205 , 0.35947, -0.11288, -0.4713 , -0.07622, 0.20206, 0.36815, -0.10552, 0.29709, 0.13807, -0.34253, 0.34833, -0.16579, -0.3794 , 0.52595, -0.12898, 0.37718, 0.10546, 0.10909], dtype=float32)
# Sanity-check the embedding space: semantically related words correlate strongly.
np.corrcoef(ft_vecs.get_word_vector('boat'), ft_vecs.get_word_vector('dinghy'))
array([[1. , 0.65917], [0.65917, 1. ]])
np.corrcoef(ft_vecs.get_word_vector('ketchup'), ft_vecs.get_word_vector('chilli'))
array([[1. , 0.61188], [0.61188, 1. ]])
# Full fastText vocabulary paired with per-word counts.
ft_words = ft_vecs.get_words(include_freq=True)
ft_word_dict = {k:v for k,v in zip(*ft_words)}
# Order the vocabulary by the dict value — presumably a frequency rank
# (lower = more common; 'bibimbap' maps to 130 in the cell below). TODO confirm.
ft_words = sorted(ft_word_dict.keys(), key=lambda x: ft_word_dict[x])
len(ft_words)
2519370
ft_word_dict['bibimbap']
130
from fastai.io import get_data
CLASSES_FN = PATH/'meta/classes.txt'
WORDS_FN = 'classids.txt'
# Fetch the WordNet synset-id → noun mapping used to widen the label vocabulary.
get_data(f'http://files.fast.ai/data/{WORDS_FN}', PATH/WORDS_FN)
# NOTE(review): CLASSES_FN is already an absolute path, so TMP_PATH/CLASSES_FN
# resolves to CLASSES_FN itself (pathlib discards the left operand when the
# right operand is absolute) — works, but the TMP_PATH prefix is a no-op.
classes_101 = (TMP_PATH/CLASSES_FN).open().readlines()
nclass = len(classes_101); nclass
101
classes_101[:10]
['apple_pie\n', 'baby_back_ribs\n', 'baklava\n', 'beef_carpaccio\n', 'beef_tartare\n', 'beet_salad\n', 'beignets\n', 'bibimbap\n', 'bread_pudding\n', 'breakfast_burrito\n']
classes_101[0]
'apple_pie\n'
# Read "synset_id word" lines, e.g. "n00001740 entity".
classid_lines = (PATH/WORDS_FN).open().readlines()
classid_lines[:5]
['n00001740 entity\n', 'n00001930 physical_entity\n', 'n00002137 abstraction\n', 'n00002452 thing\n', 'n00002684 object\n']
# Union of all WordNet noun names with the 101 food classes — the widened
# vocabulary used for zero-shot retrieval later in the notebook.
wordnet = [l.strip().split()[1] for l in classid_lines] + [l.strip() for l in classes_101]
classes_all = list(set(wordnet))
len(classes_all),len(classes_101)
(67962, 101)
wordnet[-1000]
'compassionate_leave'
# fastText vector for every candidate word, keyed by its lower-cased form.
lc_vec_d = {w.lower(): ft_vecs.get_word_vector(w) for w in classes_all}
lc_vec_d_1k = {w.lower().strip(): ft_vecs.get_word_vector(w) for w in classes_101}
# NOTE(review): the `if` filter is always true — lc_vec_d was just built from
# these same words lower-cased — so word2wv covers all of classes_all.
word_wv = [(w, lc_vec_d[w.lower()]) for w in classes_all
if w.lower() in lc_vec_d]
word2wv = dict(word_wv)
# Cache the vector dictionaries so later runs can skip loading the fastText model.
pickle.dump(lc_vec_d, (TMP_PATH/'lc_vec_d.pkl').open('wb'))
pickle.dump(lc_vec_d_1k, (TMP_PATH/'lc_vec_d_1k.pkl').open('wb'))
lc_vec_d = pickle.load((TMP_PATH/'lc_vec_d.pkl').open('rb'))
lc_vec_d_1k = pickle.load((TMP_PATH/'lc_vec_d_1k.pkl').open('rb'))
# Build (image path, target word-vector) training pairs: every image in a class
# directory regresses onto that class's fastText vector.
images = []
img_vecs = []
for dire in (PATH/'images').iterdir():
    # Skip class directories whose name has no fastText entry in the lookup.
    if dire.name not in lc_vec_d: continue
    vec = lc_vec_d[dire.name]
    for f in dire.iterdir():
        images.append(str(f.relative_to(PATH)))
        img_vecs.append(vec)
images[0]
'images/beignets/2885220.jpg'
img_vecs[0]
array([-0.03161, 0.09492, -0.11248, -0.27354, -0.21677, 0.4291 , 0.1963 , -0.25837, 0.52544, -0.00154, 0.21663, 0.11508, -0.15019, -0.19484, -0.17093, -0.29794, 0.00695, -0.0194 , 0.21036, -0.12465, -0.36858, -0.05815, 0.39113, 0.455 , -0.10581, 0.21249, 0.02394, -0.44104, 0.20809, 0.04945, -0.47479, -0.27543, -0.60547, -0.11644, -0.01707, 0.18735, -0.20743, -0.18493, -0.29754, 0.09811, -0.11574, -0.03182, 0.10966, -0.00404, -0.12068, 0.14162, -0.15134, -0.34659, -0.11288, 0.65519, -0.27983, -0.35147, -0.2659 , -0.1281 , -0.12126, -0.01779, -0.17949, -0.02894, 0.02253, 0.14515, 0.49039, -0.03 , -0.04617, -0.311 , -0.13614, -0.20199, 0.1889 , 0.06929, 0.27933, 0.19121, 0.00664, 0.2085 , 0.40239, -0.08814, -0.15157, 0.70003, 0.27375, 0.12205, -0.04917, -0.31967, 0.20337, 0.11432, 0.03515, -0.18803, -0.25561, -0.09772, 0.21651, 0.3304 , 0.37034, 0.33194, 0.21715, 0.04101, 0.02677, -0.47338, 0.42778, 0.03258, -0.03556, 0.0261 , 0.23494, -0.00673, -0.19062, -0.28234, -0.10476, -0.08403, 0.35976, -0.00705, -0.02455, 0.09042, -0.15272, -0.3764 , -0.38272, -0.17568, -0.0931 , 0.3606 , -0.27679, -0.03736, 0.01043, -0.38535, 0.50184, 0.12553, 0.27623, 0.32785, 0.03588, 0.32066, -0.06906, 0.04528, 0.38887, 0.03629, 0.01272, 0.52859, -0.15317, 0.01514, -0.44457, 0.07255, 0.24553, -0.15377, -0.03104, -0.01438, -0.28156, 0.0108 , 0.02379, 0.21121, 0.32242, 0.31632, -0.00263, 0.51931, -0.2093 , -0.35699, 0.00492, 0.16245, 0.21134, -0.42873, -0.44965, -0.27367, 0.04797, -0.25307, 0.19076, -0.00241, 0.28667, 0.04986, -0.0186 , 0.23348, -0.14194, 0.30044, 0.25293, 0.24953, -0.21796, -0.29275, 0.95483, 0.04274, -0.43892, -0.39237, 0.15195, -0.06724, 0.12623, 0.42183, -0.22326, -0.07582, -0.15509, -0.17529, 0.23283, 0.02721, 0.46088, 0.0139 , -0.30143, 0.02938, -0.04182, -0.37551, -0.24746, -0.00211, 0.31085, -0.2042 , 0.21051, 0.66146, -0.39704, -0.2245 , 0.1762 , 0.37517, 0.02941, 0.22829, 0.10166, 0.34434, 0.30977, -0.11807, 0.2225 , -0.47804, -0.3854 , 0.11067, 0.09205, 0.15174, 
0.38522, -0.09034, 0.05449, 0.06648, -0.0662 , 0.13788, 0.0016 , -0.12287, 0.19405, -0.11108, 0.31104, -0.00782, -0.36467, -0.17175, -0.14191, 0.15226, 0.00555, 0.18843, 0.26433, -0.36739, -0.03156, 0.1255 , -0.02819, 0.02782, -0.16926, -0.21894, -0.02851, 0.06039, -0.16733, 0.00587, -0.33398, -0.02439, 0.24347, 0.17343, -0.46496, 0.44183, 0.77414, 0.12719, -0.11309, 0.06603, 0.06252, 0.01696, 0.2589 , -0.35975, 0.07167, 0.01114, -0.05407, -0.32063, 0.19737, 0.04926, -0.07892, 0.47451, 0.15008, 0.05304, -0.16067, -0.22415, -0.37121, 0.51919, 0.0268 , 0.31568, -0.49628, -0.10531, 0.50396, -0.02201, -0.56316, -0.20446, -0.06742, 0.3241 , -0.01644, 0.18642, 0.01005, 0.35817, 0.00385, 0.2066 , -0.06989, 0.46378, 0.13806, -0.08049, -0.27685, 0.25909, -0.10811, -0.19301, -0.1227 , -0.09205, -0.22184, 0.22625, 0.03328, 0.50427, 0.22187, 0.28552], dtype=float32)
# 101 classes x 1000 images, each with a 300-d fastText regression target.
img_vecs = np.stack(img_vecs)
img_vecs.shape
(101000, 300)
# Cache the file list and target matrix.
pickle.dump(images, (TMP_PATH/'images.pkl').open('wb'))
pickle.dump(img_vecs, (TMP_PATH/'img_vecs.pkl').open('wb'))
images = pickle.load((TMP_PATH/'images.pkl').open('rb'))
img_vecs = pickle.load((TMP_PATH/'img_vecs.pkl').open('rb'))
arch = resnet34
n = len(images); n
101000
# Fixed-seed validation split for reproducibility across notebook runs.
val_idxs = get_cv_idxs(n, seed=27)
val_idxs
array([20048, 62177, 56946, ..., 45000, 44824, 96268])
tfms = tfms_from_model(arch, 224, transforms_side_on, max_zoom=1.1)
# continuous=True: the targets are word vectors (regression), not class labels.
md = ImageClassifierData.from_names_and_array(PATH, images, img_vecs, val_idxs=val_idxs,
classes=None, tfms=tfms, continuous=True, bs=256)
# Pull one validation batch to confirm the pipeline produces tensors.
x,y = next(iter(md.val_dl))
# Regression head (is_reg=True): ResNet-34 body + one 1024-unit hidden FC layer,
# dropout 0.5 on both head layers; output width comes from md.c (the target dim).
models = ConvnetBuilder(arch, md.c, is_multi=False, is_reg=True, xtra_fc=[1024], ps=[0.5,0.5])
learn = ConvLearner(md, models, precompute=True)
learn.opt_fn = partial(optim.Adam, betas=(0.9,0.95))
def cos_loss(inp, targ):
    """Cosine loss: 1 minus the mean cosine similarity between rows of
    `inp` and `targ`. Zero when each prediction points in exactly the same
    direction as its target; vector magnitudes are ignored.
    """
    similarity = F.cosine_similarity(inp, targ)
    return 1 - similarity.mean()
learn.crit = cos_loss
# Extremely wide LR sweep; the resulting plot is used to pick lr = 1e-2 below.
learn.lr_find(start_lr=1e-14, end_lr=1e15, wds=0.02)
HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))
93%|█████████▎| 294/316 [00:40<00:02, 7.35it/s, loss=nan]
learn.sched.plot(10, 1)
lr = 1e-2
wd = 1e-7
# Stage 1: train only the head on precomputed body activations,
# 20-epoch 1-cycle schedule (use_clr_beta).
learn.precompute=True
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr_beta=(10,20, 0.95, 0.85))
HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))
43%|████▎ | 136/316 [00:07<00:10, 17.38it/s, loss=0.383] epoch trn_loss val_loss 0 0.285495 0.259447 1 0.247486 0.218731 2 0.233475 0.203851 3 0.229529 0.199069 4 0.226356 0.196048 5 0.223297 0.192082 6 0.22051 0.192349 7 0.221322 0.18906 8 0.218742 0.187842 9 0.214297 0.185845 10 0.209901 0.185604 11 0.20772 0.180261 12 0.204038 0.178309 13 0.200772 0.178253 14 0.198641 0.17676 15 0.195599 0.175979 16 0.19615 0.176313 17 0.195417 0.176696 18 0.191627 0.175177 19 0.193891 0.174636
[array([0.17464])]
learn.save('p17_loss_1cycle_p001lr')
learn.load('p17_loss_1cycle_p001lr')
# Freeze batch-norm layer statistics for the next training round.
learn.bn_freeze(True)
learn.fit(lr, 1, cycle_len=20, wds=wd, use_clr=(10,20))
HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))
epoch trn_loss val_loss 0 0.190927 0.173675 1 0.195457 0.175798 2 0.19453 0.175052 3 0.192118 0.176636 4 0.191071 0.174743 5 0.189897 0.174653 6 0.186789 0.17064 7 0.185151 0.172911 8 0.18549 0.169044 9 0.182516 0.170845 10 0.182815 0.169598 11 0.181223 0.169691 12 0.17975 0.168432 13 0.178964 0.167944 14 0.17846 0.169226 15 0.176993 0.167805 16 0.179309 0.168353 17 0.17549 0.16711 18 0.175561 0.166816 19 0.17455 0.166416
[array([0.16642])]
learn.save('p16_loss_1cycle_p001lr')
# Discriminative learning rates for the [early, middle, head] layer groups.
lrs = np.array([lr/1000,lr/100,lr])
# Stage 2: train on full images (no precomputed activations).
learn.precompute=False
# NOTE(review): freeze_to(1) keeps layer group 0 frozen while unfreezing the
# rest — confirm that was the intent (vs. unfreeze()).
learn.freeze_to(1)
learn.lr_find(start_lr=1e-4, end_lr=1e15)
HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))
72%|███████▏ | 229/316 [09:34<03:38, 2.51s/it, loss=nan]
learn.sched.plot()
learn.save('pre0')
learn.load('pre0')
# The 101 class names and their word vectors, used for nearest-neighbour decoding.
syns, wvs = list(lc_vec_d_1k.keys()), list(lc_vec_d_1k.values())
wvs = np.array(wvs)
syns[0]
'apple_pie'
# Predict a 300-d word vector for every validation image.
%time pred_wv = learn.predict()
CPU times: user 5min 7s, sys: 2.21 s, total: 5min 10s Wall time: 1min 21s
# Offset into the validation set used for the qualitative displays below.
start=600
denorm = md.val_ds.denorm
def show_img(im, figsize=None, ax=None):
    """Display image `im` on `ax` (created on demand) with the axis hidden.

    Returns the Axes so callers can add further annotation.
    """
    # `is None` rather than truthiness: a caller-supplied Axes must never be
    # mistaken for "absent" just because it evaluates falsy.
    if ax is None: fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(im)
    ax.axis('off')
    return ax
def show_imgs(ims, cols, figsize=None):
    """Show the images `ims` in a grid of `cols` columns (len(ims)//cols rows)."""
    n_rows = len(ims) // cols
    fig, axes = plt.subplots(n_rows, cols, figsize=figsize)
    for image, ax in zip(ims, axes.flat):
        show_img(image, ax=ax)
    plt.tight_layout()
show_imgs(denorm(md.val_ds[start:start+25][0]), 5, (10,10))
import nmslib
def create_index(a):
    """Build an nmslib k-NN index over the rows of `a` using angular distance."""
    knn_index = nmslib.init(space='angulardist')
    knn_index.addDataPointBatch(a)
    knn_index.createIndex()
    return knn_index
def get_knns(index, vecs):
    """Batched 10-nearest-neighbour query.

    Returns a pair of sequences: (neighbour ids per query, distances per query).
    """
    batch = index.knnQueryBatch(vecs, k=10, num_threads=4)
    return zip(*batch)
def get_knn(index, vec):
    """Single-query 10 nearest neighbours: returns (ids, distances)."""
    result = index.knnQuery(vec, k=10)
    return result
# Index the 101 class word-vectors and decode each predicted vector to its
# nearest class names.
nn_wvs = create_index(wvs)
idxs,dists = get_knns(nn_wvs, pred_wv)
# Top-3 class names for ten consecutive validation images.
[[syns[id] for id in ids[:3]] for ids in idxs[start:start+10]]
[['grilled_cheese_sandwich', 'pulled_pork_sandwich', 'hamburger'], ['hamburger', 'pizza', 'steak'], ['grilled_cheese_sandwich', 'pulled_pork_sandwich', 'garlic_bread'], ['hamburger', 'grilled_cheese_sandwich', 'bread_pudding'], ['hamburger', 'grilled_cheese_sandwich', 'pulled_pork_sandwich'], ['hamburger', 'chicken_quesadilla', 'bread_pudding'], ['pulled_pork_sandwich', 'grilled_cheese_sandwich', 'lobster_roll_sandwich'], ['hamburger', 'pizza', 'steak'], ['hamburger', 'chicken_quesadilla', 'bread_pudding'], ['hot_dog', 'grilled_cheese_sandwich', 'garlic_bread']]
# Repeat the search against ALL WordNet nouns, not just the 101 training
# classes — zero-shot retrieval beyond the training label set.
all_syns, all_wvs = list(zip(*word2wv.items()))
all_wvs = np.array(all_wvs)
nn_allwvs = create_index(all_wvs)
idxs,dists = get_knns(nn_allwvs, pred_wv)
[[all_syns[id] for id in ids[:3]] for ids in idxs[start:start+10]]
[['hamburger_bun', 'chicken_sandwich', 'pork_sausage'], ['hamburger', 'Limburger', 'Luxemburger'], ['grilled_cheese_sandwich', 'pulled_pork_sandwich', 'garlic_bread'], ['hamburger_bun', 'hamburger', 'pork_sausage'], ['hamburger_bun', 'steak_sauce', 'chicken_sandwich'], ['hamburger', 'steak_sauce', 'hamburger_bun'], ['pulled_pork_sandwich', 'grilled_cheese_sandwich', 'pork_sausage'], ['hamburger', 'Limburger', 'Luxemburger'], ['hamburger', 'Limburger', 'hamburger_bun'], ['hot_dog', 'hot_sauce', 'grilled_cheese_sandwich']]
# Reverse direction: index the *predicted* image vectors so we can retrieve
# images from an arbitrary text query's word vector.
nn_predwv = create_index(pred_wv)
# en_vecd = pickle.load(open(TRANS_PATH/'wiki.en.pkl','rb'))
# vec = en_vecd['boat']
vec = ft_vecs.get_word_vector('boat')
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));
# vec = en_vecd['boat']
vec = ft_vecs.get_word_vector('seafood')
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:6]], 3, figsize=(9,3));
# Averaging two word vectors approximates a compositional query ("roast chicken").
vec = (ft_vecs.get_word_vector('roast') + ft_vecs.get_word_vector('chicken'))/2
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));
vec = (ft_vecs.get_word_vector('egg') + ft_vecs.get_word_vector('mix'))/2
idxs,dists = get_knn(nn_predwv, vec)
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[:3]], 3, figsize=(9,3));
# Image-to-image search: embed an out-of-dataset photo with the trained model
# and find the validation images whose predicted vectors are nearest.
fname = '/data/food-101/Tiramisu_with_blueberries_and_raspberries.jpg'
img = open_image(fname)
show_img(img); # from https://en.wikipedia.org/wiki/Tiramisu
t_img = md.val_ds.transform(img)
pred = learn.predict_array(t_img[None])
idxs,dists = get_knn(nn_predwv, pred)
# Slicing idxs[1:4] presumably skips the closest match — confirm why index 0
# is dropped (the query image itself is not in the index).
show_imgs([open_image(PATH/md.val_ds.fnames[i]) for i in idxs[1:4]], 3, figsize=(9,3));