In [1]:
# Automatically reload imported modules before each cell runs, so edits to
# library code are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
In [2]:
from fastai import *
from fastai.vision import *
import xml.etree.ElementTree as ET
import pandas as pd
import os
import glob2 as glob
In [3]:
# Dataset root: it holds the `train/` folder with the XML annotations and is
# also where `train.csv` is written and read back.
# Set the IMAGECLEF2013_DIR environment variable to point at a different
# location instead of editing this cell; the default is the original path.
PATH = Path(os.environ.get('IMAGECLEF2013_DIR', '/data2/imageCLEF2013/'))
In [9]:
# Parse every per-image XML annotation under PATH/'train' into one record per
# image. The records are collected in `image_list`; a later cell turns them
# into the `trn_df` DataFrame.
def _field_text(tree, tag, source):
    """Return the text of top-level element `tag`, naming `source` if it is absent."""
    node = tree.find(tag)
    if node is None:
        # Fail with the offending file name rather than an opaque
        # "'NoneType' object has no attribute 'text'".
        raise ValueError(f'{source}: missing <{tag}> element')
    return node.text


image_list = []

# Path.glob yields entries in arbitrary filesystem order; sorting keeps the
# row order of the resulting table (and of the cached CSV) stable across runs.
for f in sorted((PATH/'train').glob('*.xml')):
    tree = ET.parse(f)
    # Store the image path relative to PATH, i.e. prefixed with the 'train' folder.
    fn = os.path.join('train', _field_text(tree, 'FileName', f))
    plantid = _field_text(tree, 'IndividualPlantId', f)
    bg_type = _field_text(tree, 'Type', f)
    content = _field_text(tree, 'Content', f)
    class_id = _field_text(tree, 'ClassId', f)
    image_list.append([fn, plantid, bg_type, content, class_id])
In [19]:
# Assemble the training table from the parsed records; the column names follow
# the positional order of the values in each `image_list` entry.
trn_df = pd.DataFrame(
    data=image_list,
    columns=['filename', 'plantid', 'background_type', 'content', 'classid'],
)
In [20]:
# Write the parsed table to PATH/'train.csv' (without the index column) so it
# can be loaded directly instead of re-parsing the XML files.
trn_df.to_csv(PATH/'train.csv', index=False)
In [4]:
# Load the training table from PATH/'train.csv'.
trn_df = pd.read_csv(PATH/'train.csv')
In [5]:
# Preview the first rows to check the columns look as expected.
trn_df.head()
Out[5]:
filename plantid background_type content classid
0 train/0.jpg 470 NaturalBackground Leaf Corylus avellana
1 train/1.jpg 246 SheetAsBackground Leaf Ruscus aculeatus
2 train/2.jpg 94 SheetAsBackground Leaf Phillyrea angustifolia
3 train/3.jpg 742 SheetAsBackground Leaf Rhamnus alaternus
4 train/4.jpg 630 NaturalBackground Leaf Hedera helix
In [9]:
# Number of rows (images) per class label, most frequent first; shows how
# imbalanced the classes are.
trn_df['classid'].value_counts()
Out[9]:
Quercus ilex                   437
Ulmus minor                    398
Viburnum tinus                 383
Phillyrea angustifolia         358
Hedera helix                   348
Cercis siliquastrum            347
Olea europaea                  322
Celtis australis               319
Buxus sempervirens             305
Populus nigra                  305
Pittosporum tobira             304
Robinia pseudoacacia           299
Crataegus monogyna             297
Platanus x hispanica           285
Arbutus unedo                  284
Ruscus aculeatus               267
Betula pendula                 258
Acer negundo                   250
Nerium oleander                239
Cotinus coggygria              238
Euphorbia characias            236
Populus alba                   228
Aesculus hippocastanum         227
Corylus avellana               225
Diospyros kaki                 223
Acer campestre                 217
Acer monspessulanum            212
Ginkgo biloba                  201
Pistacia lentiscus             200
Rhus coriaria                  196
                              ... 
Limonium vulgare                19
Papaver somniferum              19
Convolvulus cantabrica          19
Aphyllanthes monspeliensis      19
Centaurea jacea                 18
Soldanella alpina               18
Ruta montana                    18
Cakile maritima                 17
Limodorum abortivum             17
Allium triquetrum               17
Monotropa hypopitys             17
Stachys sylvatica               17
Campanula trachelium            16
Sedum sediforme                 16
Gentianella ciliata             16
Potentilla recta                16
Cyanus montanus                 16
Globularia repens               16
Dryas octopetala                15
Linum campanulatum              15
Blechnum spicant                14
Plantago media                  14
Gentiana verna                  14
Gentiana pneumonanthe           13
Aster alpinus                   13
Populus trichocarpa             13
Asplenium scolopendrium         13
Parthenocissus tricuspidata     12
Matthiola sinuata               12
Carthamus lanatus               11
Name: classid, Length: 250, dtype: int64
In [5]:
# `from_df` holds out a random validation subset. Seed NumPy's RNG first so the
# same images land in train/validation on every run and results stay comparable.
np.random.seed(42)
# Columns are addressed by position: 0 is 'filename', 4 is 'classid'.
# Augmentation uses get_transforms with vertical flips enabled; images are
# requested at size 512 in batches of 24.
data = ImageDataBunch.from_df(PATH, trn_df, fn_col=0, 
                              label_col=4, ds_tfms=get_transforms(flip_vert=True), size=512, bs=24)

Plot the distribution of image sizes to decide on sensible size limits for progressive resizing.

In [6]:
# Inputs of the training dataset; they are passed to PIL.Image.open further
# down, i.e. they are used as image file paths.
train_fnames = data.train_ds.ds.x
In [7]:
# Sanity check: inspect which container type holds the training inputs.
type(train_fnames)
Out[7]:
numpy.ndarray
In [10]:
# Import the submodule explicitly: a bare `import PIL` does not load
# `PIL.Image`, so `PIL.Image.open` would only work when some other import
# happened to pull the submodule in first.
import PIL.Image
In [11]:
def _image_size(fname):
    """Return PIL's `size` tuple, (width, height), for the image at `fname`."""
    # The context manager closes the file as soon as the header has been read,
    # instead of leaving one open handle per training image to the garbage collector.
    with PIL.Image.open(fname) as img:
        return img.size


# Map every training file name to its image size.
size_d = {k: _image_size(k) for k in train_fnames}
In [12]:
# Transpose the per-image size pairs into two parallel tuples: the first
# component of every pair goes to `row_sz`, the second to `col_sz`.
row_sz, col_sz = zip(*size_d.values())
In [13]:
# Convert both size sequences to NumPy arrays for plotting and statistics.
row_sz, col_sz = np.array(row_sz), np.array(col_sz)
In [14]:
# `row_sz` holds the first component of PIL's `size` tuples, i.e. the image
# widths in pixels. Label the figure so it can be read on its own.
plt.hist(row_sz)
plt.xlabel('image width (px)')
plt.ylabel('number of training images')
plt.title('Distribution of training image widths');
In [11]:
# Number of distinct class labels in the DataBunch; it should match the number
# of unique `classid` values in the training table.
len(data.classes)
Out[11]:
250
In [8]:
# Normalise the images with the ImageNet statistics provided by fastai.
# NOTE(review): presumably chosen because an ImageNet-pretrained model is
# trained on this data later; confirm. The call appears to modify `data` in
# place, so re-running this cell may not be idempotent; verify against the
# installed fastai version.
data.normalize(imagenet_stats)
In [10]:
# Visual sanity check: display sample images from a batch (4 rows, 12x12 inch
# figure) to eyeball the images and their labels.
data.show_batch(rows=4, figsize=(12, 12))