#!/usr/bin/env python # coding: utf-8 # # MURA | Abnormality detection # - Author: [Pierre Guillou](https://www.linkedin.com/in/pierreguillou) # - Date: March 2019 # - MURA Dataset: https://stanfordmlgroup.github.io/competitions/mura/ # - Ref: [Fastai v1](https://docs.fast.ai/) (Deep Learning library on PyTorch) # - Post in medium: https://medium.com/@pierre_guillou/fastai-the-new-radiology-tool-76f02c1e25bf # ## What is MURA? # MURA (musculoskeletal radiographs) is a large dataset of bone X-rays. Algorithms are tasked with determining whether an X-ray study is normal or abnormal. # # Musculoskeletal conditions affect more than 1.7 billion people worldwide, and are the most common cause of severe, long-term pain and disability, with 30 million emergency department visits annually and increasing. We hope that our dataset can lead to significant advances in medical imaging technologies which can diagnose at the level of experts, towards improving healthcare access in parts of the world where access to skilled radiologists is limited. # # MURA is one of the largest public radiographic image datasets. We're making this dataset available to the community and hosting a competition to see if your models can perform as well as radiologists on the task. 
# ## Initialisation -----------------------------------------------------
# Notebook-export script: trains abnormality classifiers on the MURA
# bone X-ray dataset with fastai v1 (resnet34 here, densenet169 below).
# NOTE(review): relies on `from fastai.vision import *` for np, pd, plt,
# torch, Config, get_image_files, ImageDataBunch, cnn_learner, etc.

get_ipython().run_line_magic('reload_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')
get_ipython().run_line_magic('matplotlib', 'inline')

from fastai.vision import *
from fastai.widgets import *
import shutil
from fastai.callbacks import *

import fastai
print(f'fastai: {fastai.__version__}')
print(f'cuda: {torch.cuda.is_available()}')

import gc
import torch

# ## Data ---------------------------------------------------------------
path = Config.data_path() / 'MURA-v1.1'
path.ls()

# csv files shipped with MURA: image paths and study-level labels
df_train = pd.read_csv(path / 'train_image_paths.csv', header=None, names=['image'])
df_valid = pd.read_csv(path / 'valid_image_paths.csv', header=None, names=['image'])
df_train_label = pd.read_csv(path / 'train_labeled_studies.csv', header=None,
                             names=['image', 'label'])

# ### Create folder data2 to train models
# data2/{train,valid}/{0,1} — class folders (0 = negative/normal, 1 = abnormal)
path_train = path / 'data2/train'
path_valid = path / 'data2/valid'
path_train_neg = path_train / '0'
path_train_pos = path_train / '1'
path_valid_neg = path_valid / '0'
path_valid_pos = path_valid / '1'
for p in (path_train_neg, path_train_pos, path_valid_neg, path_valid_pos):
    p.mkdir(parents=True, exist_ok=True)

# ### Get list of images
fnames_train = get_image_files(path / 'train', recurse=True)
fnames_valid = get_image_files(path / 'valid', recurse=True)
print(len(fnames_train), len(fnames_valid))

# ### Copy images into data2
# Regexes over the original MURA path layout:
#   .../XR_<bodypart>/patient<id>/study<n>_<negative|positive>/<image>.png
# BUGFIX: '.' before 'png' is now escaped — the originals matched any character.
pat_label = re.compile(r'/XR_([^/]+)/[^/]+/[^/]+/[^/]+\.png$')
pat_patient = re.compile(r'/[^/]+/patient([^/]+)/[^/]+/[^/]+\.png$')
pat_study = re.compile(r'/[^/]+/[^/]+/([^/]+)/[^/]+\.png$')


def _copy_to_class_folders(fnames, path_neg, path_pos):
    """Copy every image into its 0/1 class folder, encoding the
    bodypart/patient/study into the destination file name.
    (The original had two byte-identical loops for train and valid.)"""
    for src in fnames:
        label = pat_label.search(str(src)).group(1)
        patient = pat_patient.search(str(src)).group(1)
        study = pat_study.search(str(src)).group(1)
        # 'negative' in the study folder name marks a normal study
        path_label = path_neg if 'negative' in study else path_pos
        img_name = label + '_patient' + patient + '_' + study + '_' + src.name
        shutil.copy(str(src), str(path_label / img_name))


_copy_to_class_folders(fnames_train, path_train_neg, path_train_pos)
_copy_to_class_folders(fnames_valid, path_valid_neg, path_valid_pos)

# ### Number of studies
# pat_study now captures 'study<n>_<neg|pos>' from the original layout
pat_study = re.compile(r'/([^/]+)_[^/]+\.png$')

mura = ['elbow', 'finger', 'forearm', 'hand', 'humerus', 'shoulder', 'wrist']


def _collect_studies(fnames):
    """Map each body part to the list of 'patient<id>_study<n>' keys seen."""
    d = {m: [] for m in mura}
    for src in fnames:
        label = pat_label.search(str(src)).group(1)
        patient = pat_patient.search(str(src)).group(1)
        study = pat_study.search(str(src)).group(1)
        d[label.lower()].append('patient' + patient + '_' + study)
    return d


study_train_dict = _collect_studies(fnames_train)
study_valid_dict = _collect_studies(fnames_valid)

num_train_studies = sum(len(set(study_train_dict[m])) for m in mura)
num_valid_studies = sum(len(set(study_valid_dict[m])) for m in mura)
# 207 further studies live in the hidden MURA test set
print(num_train_studies, num_valid_studies,
      num_train_studies + num_valid_studies + 207)

# ## Training with resnet34 ---------------------------------------------
# ### size = 112
size = 112
bs = 512
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

data.show_batch(rows=3, figsize=(7, 6))
print(data.classes)
print(len(data.train_ds), len(data.valid_ds),
      len(data.train_ds) + len(data.valid_ds))

# class balance of the training set
plt.bar([0, 1], [len(path_train_neg.ls()), len(path_train_pos.ls())])
plt.show()

learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy], wd=0.1)
# (the notebook re-ran this cell twice; a single 5-epoch run is kept)
learn.fit_one_cycle(5, callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('resnet34-stage-1')
learn.load('resnet34-stage-1')

learn.purge()
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()

lr = 1e-4
learn.fit_one_cycle(5, max_lr=slice(lr / 100, lr),
                    callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('resnet34-stage-2')

# #### Results by image
learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy], wd=0.1)
learn.load('resnet34-stage-2')
interp = ClassificationInterpretation.from_learner(learn)
losses, idxs = interp.top_losses()
print(len(data.valid_ds) == len(losses) == len(idxs))
interp.plot_confusion_matrix()
print(interp.most_confused(min_val=2))
interp.plot_top_losses(9, figsize=(15, 11))

# #### Results by study
preds_val, y_val = learn.get_preds()

# data2 file names look like: elbow_patient00011_study1_negative_image1.png
pat_label = re.compile(r'/([^/]+)_patient[^/]+\.png$')
pat_study = re.compile(r'/([^/]+)_[^_]+\.png$')


def _aggregate_by_study(preds, items):
    """Sum image-level class probabilities per study.

    Returns (studies: study -> summed prob tensor,
             counts:  study -> number of images,
             labels:  study -> body-part label).
    BUGFIX: the original cell never initialised the `labels` dict
    (it created `labels_num` instead), raising a NameError on first use.
    """
    studies, counts, labels = {}, {}, {}
    for idx, src in enumerate(items):
        label = pat_label.search(str(src)).group(1)
        study = pat_study.search(str(src)).group(1)
        if study in studies:
            studies[study] += preds[idx, :].clone()
            counts[study] += 1
        else:
            studies[study] = preds[idx, :].clone()
            counts[study] = 1
            labels[study] = label
    return studies, counts, labels


def _report_study_accuracy(studies, counts, labels):
    """Average the summed probabilities per study, take the argmax as the
    study-level prediction ('negative' in the study key means class 0)
    and print overall and per-body-part accuracy."""
    labels_num = {m: sum(1 for v in labels.values() if v.lower() == m)
                  for m in mura}
    print(labels_num, sum(labels_num.values()))
    for k in studies:
        studies[k] = studies[k] / counts[k]
    acc = 0.
    acc_label = {m: 0 for m in mura}
    for k, probs in studies.items():
        prob, y_hat = torch.max(probs, 0)
        target = 0 if 'negative' in k else 1
        correct = int(target == y_hat.item())
        acc += correct
        acc_label[labels[k].lower()] += correct
    print(len(studies), acc)
    print(f'study accuracy total: {round(acc / len(studies), 3)}')
    for m in mura:
        print(f'{m}: {round(acc_label[m] / labels_num[m], 3)}')


studies, studies_num, labels = _aggregate_by_study(preds_val, data.valid_ds.x.items)
_report_study_accuracy(studies, studies_num, labels)

# ### size = 224 (progressive resizing: continue from the 112-px weights)
learn = None
gc.collect()
torch.cuda.empty_cache()

size = 224
bs = 128
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy], wd=0.1)
learn.load('resnet34-stage-2')
learn.fit_one_cycle(5, callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('resnet34-stage-3')

# fresh learner to release GPU memory before the LR sweep
learn = None
gc.collect()
torch.cuda.empty_cache()
learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy], wd=0.1)
learn.load('resnet34-stage-3')
learn.lr_find()
learn.recorder.plot()
lr = 3e-4
# frozen fine-tune of the resnet34 head at 224 px with the LR chosen above
learn.fit_one_cycle(5, max_lr=lr,
                    callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('resnet34-stage-4')
learn.load('resnet34-stage-4')

learn.purge()
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()

lr = 3e-6
learn.fit_one_cycle(5, max_lr=slice(lr / 100, lr),
                    callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('resnet34-stage-5')

# #### Results by image (resnet34 @ 224)
size = 224
bs = 128
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

learn = cnn_learner(data, models.resnet34, metrics=[error_rate, accuracy], wd=0.1)
learn.load('resnet34-stage-5')
interp = ClassificationInterpretation.from_learner(learn)
losses, idxs = interp.top_losses()
print(len(data.valid_ds) == len(losses) == len(idxs))
interp.plot_confusion_matrix()
print(interp.most_confused(min_val=2))
interp.plot_top_losses(9, figsize=(15, 11))

# #### Results by study (resnet34 @ 224)
preds_val, y_val = learn.get_preds()

# data2 file names look like: elbow_patient00011_study1_negative_image1.png
pat_label = re.compile(r'/([^/]+)_patient[^/]+\.png$')
pat_study = re.compile(r'/([^/]+)_[^_]+\.png$')


def _aggregate_by_study(preds, items):
    """Sum image-level class probabilities per study.

    Returns (studies: study -> summed prob tensor,
             counts:  study -> number of images,
             labels:  study -> body-part label).
    BUGFIX: the original cell never initialised the `labels` dict
    (it created `labels_num` instead), raising a NameError on first use.
    """
    studies, counts, labels = {}, {}, {}
    for idx, src in enumerate(items):
        label = pat_label.search(str(src)).group(1)
        study = pat_study.search(str(src)).group(1)
        if study in studies:
            studies[study] += preds[idx, :].clone()
            counts[study] += 1
        else:
            studies[study] = preds[idx, :].clone()
            counts[study] = 1
            labels[study] = label
    return studies, counts, labels


def _report_study_accuracy(studies, counts, labels):
    """Average the summed probabilities per study, take the argmax as the
    study-level prediction ('negative' in the study key means class 0)
    and print overall and per-body-part accuracy."""
    labels_num = {m: sum(1 for v in labels.values() if v.lower() == m)
                  for m in mura}
    print(labels_num, sum(labels_num.values()))
    for k in studies:
        studies[k] = studies[k] / counts[k]
    acc = 0.
    acc_label = {m: 0 for m in mura}
    for k, probs in studies.items():
        prob, y_hat = torch.max(probs, 0)
        target = 0 if 'negative' in k else 1
        correct = int(target == y_hat.item())
        acc += correct
        acc_label[labels[k].lower()] += correct
    print(len(studies), acc)
    print(f'study accuracy total: {round(acc / len(studies), 3)}')
    for m in mura:
        print(f'{m}: {round(acc_label[m] / labels_num[m], 3)}')


studies, studies_num, labels = _aggregate_by_study(preds_val, data.valid_ds.x.items)
_report_study_accuracy(studies, studies_num, labels)

# ## Training with densenet169 ------------------------------------------
# ### size = 112
learn = None
gc.collect()
torch.cuda.empty_cache()

size = 112
bs = 256
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

learn = cnn_learner(data, models.densenet169, metrics=[error_rate, accuracy], wd=0.1)
learn = learn.to_fp16()  # mixed precision to fit the larger model in memory
learn.fit_one_cycle(5, callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('densenet169-stage-1')
# reload best weights; to_fp16 must be re-applied after load
learn.load('densenet169-stage-1')
learn = learn.to_fp16()

learn.purge()
learn.unfreeze()
# unfrozen fine-tune of densenet169 @ 112 px
learn.lr_find()
learn.recorder.plot()
lr = 1e-4
learn.fit_one_cycle(5, max_lr=slice(lr / 100, lr),
                    callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('densenet169-stage-2')

# #### Results by image (densenet169 @ 112)
size = 112
bs = 256
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

learn = cnn_learner(data, models.densenet169, metrics=[error_rate, accuracy], wd=0.1)
learn.load('densenet169-stage-2')
interp = ClassificationInterpretation.from_learner(learn)
losses, idxs = interp.top_losses()
print(len(data.valid_ds) == len(losses) == len(idxs))
interp.plot_confusion_matrix()
print(interp.most_confused(min_val=2))
interp.plot_top_losses(9, figsize=(15, 11))

# #### Results by study (densenet169 @ 112)
preds_val, y_val = learn.get_preds()

# data2 file names look like: elbow_patient00011_study1_negative_image1.png
pat_label = re.compile(r'/([^/]+)_patient[^/]+\.png$')
pat_study = re.compile(r'/([^/]+)_[^_]+\.png$')


def _aggregate_by_study(preds, items):
    """Sum image-level class probabilities per study.

    Returns (studies: study -> summed prob tensor,
             counts:  study -> number of images,
             labels:  study -> body-part label).
    BUGFIX: one original cell initialised `labels` but wrote into an
    uninitialised `labels_num`; others did the opposite — both dicts are
    created locally here.
    """
    studies, counts, labels = {}, {}, {}
    for idx, src in enumerate(items):
        label = pat_label.search(str(src)).group(1)
        study = pat_study.search(str(src)).group(1)
        if study in studies:
            studies[study] += preds[idx, :].clone()
            counts[study] += 1
        else:
            studies[study] = preds[idx, :].clone()
            counts[study] = 1
            labels[study] = label
    return studies, counts, labels


def _report_study_accuracy(studies, counts, labels):
    """Average the summed probabilities per study, take the argmax as the
    study-level prediction ('negative' in the study key means class 0)
    and print overall and per-body-part accuracy."""
    labels_num = {m: sum(1 for v in labels.values() if v.lower() == m)
                  for m in mura}
    print(labels_num, sum(labels_num.values()))
    for k in studies:
        studies[k] = studies[k] / counts[k]
    acc = 0.
    acc_label = {m: 0 for m in mura}
    for k, probs in studies.items():
        prob, y_hat = torch.max(probs, 0)
        target = 0 if 'negative' in k else 1
        correct = int(target == y_hat.item())
        acc += correct
        acc_label[labels[k].lower()] += correct
    print(len(studies), acc)
    print(f'study accuracy total: {round(acc / len(studies), 3)}')
    for m in mura:
        print(f'{m}: {round(acc_label[m] / labels_num[m], 3)}')


studies, studies_num, labels = _aggregate_by_study(preds_val, data.valid_ds.x.items)
_report_study_accuracy(studies, studies_num, labels)

# ### size = 224 + use of the kappa metric ------------------------------
learn = None
gc.collect()
torch.cuda.empty_cache()

size = 224
bs = 64
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

# quadratic-weighted Cohen's kappa, the official MURA metric
kappa = KappaScore()
kappa.weights = "quadratic"

learn = cnn_learner(data, models.densenet169,
                    metrics=[error_rate, accuracy, kappa], wd=0.1).to_fp16()
learn.load('densenet169-stage-2')
learn.purge()
learn.freeze()
learn.fit_one_cycle(10, callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('densenet169-stage-3')

# fresh learner (kappa metric dropped from here on, as in the notebook)
learn = None
gc.collect()
torch.cuda.empty_cache()
learn = cnn_learner(data, models.densenet169,
                    metrics=[error_rate, accuracy], wd=0.1).to_fp16()
learn.load('densenet169-stage-3')

learn.lr_find()
learn.recorder.plot()
lr = 1e-4
learn.fit_one_cycle(5, max_lr=lr,
                    callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('densenet169-stage-4')
learn.load('densenet169-stage-4')

learn.purge()
learn.unfreeze()
learn.lr_find()
learn.recorder.plot()

# The notebook first ran 5 epochs, then re-ran from stage-4 with 3 epochs
# and overwrote stage-5; only the final (3-epoch) variant is kept.
lr = 3e-5
learn.fit_one_cycle(3, max_lr=slice(lr / 100, lr),
                    callbacks=[ShowGraph(learn), SaveModelCallback(learn)])
learn.save('densenet169-stage-5')

# #### Results by image (densenet169 @ 224)
size = 224
bs = 64
np.random.seed(42)
data = ImageDataBunch.from_folder(
    path / 'data2',
    ds_tfms=get_transforms(flip_vert=True, max_warp=0.),
    size=size, bs=bs,
).normalize(imagenet_stats)

learn = cnn_learner(data, models.densenet169, metrics=[error_rate, accuracy], wd=0.1)
learn.load('densenet169-stage-5')
interp = ClassificationInterpretation.from_learner(learn)
losses, idxs = interp.top_losses()
print(len(data.valid_ds) == len(losses) == len(idxs))
interp.plot_confusion_matrix()
print(interp.most_confused(min_val=2))
interp.plot_top_losses(9, figsize=(15, 11))

# #### Results by study (densenet169 @ 224)
preds_val, y_val = learn.get_preds()
studies, studies_num, labels = _aggregate_by_study(preds_val, data.valid_ds.x.items)
_report_study_accuracy(studies, studies_num, labels)