Summarization with blurr

blurr is a library I started that integrates huggingface transformers with the world of fastai v2, giving fastai devs everything they need to train, evaluate, and deploy transformer-specific models. In this article, I provide a simple example of how to use blurr's new summarization capabilities to train, evaluate, and deploy a BART summarization model.

Updated on 08/21/2020 to use fastai 2.0.0 and also demo batch-time padding.
Updated on 09/25/2020 to use on-the-fly batch-time tokenization.
Updated on 11/12/2020 with support for fastai >= 2.1.5 and mixed precision.
Updated on 12/31/2020 (too many changes for one sentence; see the docs for more).

  • toc: false
  • badges: true
  • comments: true
  • author: Wayde Gilliam
  • categories: [fastai, huggingface, blurr, summarization, text generation]
  • image: images/articles/blurr-logo-small.png
  • hide: false
  • search_exclude: false
  • show_tags: true
In [1]:
# only run this cell if you are in Colab
# !pip install ohmeow-blurr -q
# !pip install datasets -q
# !pip install bert-score -q
In [2]:
import datasets
import pandas as pd
from fastai.text.all import *
from transformers import *

from blurr.data.all import *
from blurr.modeling.all import *

Data Preparation

We're going to use the datasets library from huggingface to grab our raw data. This package gives you access to all kinds of NLP-related datasets, explanations of each, and various task-specific metrics to use in evaluating your model. The best part is that every example comes back as a simple, JSON-like dictionary. This makes it a breeze to get up and running quickly!

We'll just use a subset (the first 1%) of the training split to build both our training and validation DataLoaders.

In [3]:
raw_data = datasets.load_dataset('cnn_dailymail', '3.0.0', split='train[:1%]')

Downloading and preparing dataset cnn_dailymail/3.0.0 (download: 558.32 MiB, generated: 1.28 GiB, post-processed: Unknown size, total: 1.82 GiB) to /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234...

Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234. Subsequent calls will reuse this data.
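
The split string 'train[:1%]' asks for the first 1% of the training split, not a random 1%. As a rough sketch of the arithmetic, assuming the CNN/DailyMail 3.0.0 training split holds 287,113 articles and that percentages are rounded to the closest row (which I believe is the datasets library's default), here is how that string turns into a row count. percent_slice_bounds is a hypothetical helper, not part of datasets.

```python
def percent_slice_bounds(num_examples, start_pct=0.0, end_pct=100.0):
    """Hypothetical helper: map a percent slice like 'train[:1%]' to row indices.

    Assumes 'closest' rounding; a simplified sketch, not the datasets library's code.
    """
    start = round(num_examples * start_pct / 100)
    end = round(num_examples * end_pct / 100)
    return start, end


# Assumption: the CNN/DailyMail 3.0.0 training split holds 287,113 articles
start, end = percent_slice_bounds(287_113, end_pct=1)
print(start, end, end - start)
```

Under those assumptions the slice covers rows 0 through 2,870, i.e. 2,871 articles, which matches the 2,297 + 574 items the DataLoaders report further down.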
In [4]:
df = pd.DataFrame(raw_data)
df.head()
Out[4]:
article highlights id
0 It's official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria. Obama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons. The proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction." It's a step that is set to turn an internat... Syrian official: Obama climbed to the top of the tree, "doesn't know how to get down"\nObama sends a letter to the heads of the House and Senate .\nObama to seek congressional approval on military action against Syria .\nAim is to determine whether CW were used, not by whom, says U.N. spokesman . 0001d1afc246a7964130f43ae940af6bc6c57f01
1 (CNN) -- Usain Bolt rounded off the world championships Sunday by claiming his third gold in Moscow as he anchored Jamaica to victory in the men's 4x100m relay. The fastest man in the world charged clear of United States rival Justin Gatlin as the Jamaican quartet of Nesta Carter, Kemar Bailey-Cole, Nickel Ashmeade and Bolt won in 37.36 seconds. The U.S finished second in 37.56 seconds with Canada taking the bronze after Britain were disqualified for a faulty handover. The 26-year-old Bolt has now collected eight gold medals at world championships, equaling the record held by American trio... Usain Bolt wins third gold of world championship .\nAnchors Jamaica to 4x100m relay victory .\nEighth gold at the championships for Bolt .\nJamaica double up in women's 4x100m relay . 0002095e55fcbd3a2f366d9bf92a95433dc305ef
2 Kansas City, Missouri (CNN) -- The General Services Administration, already under investigation for lavish spending, allowed an employee to telecommute from Hawaii even though he is based at the GSA's Kansas City, Missouri, office, a CNN investigation has found. It cost more than $24,000 for the business development specialist to travel to and from the mainland United States over the past year. He is among several hundred GSA "virtual" workers who also travel to various conferences and their home offices, costing the agency millions of dollars over the past three years. Under the program, ... The employee in agency's Kansas City office is among hundreds of "virtual" workers .\nThe employee's travel to and from the mainland U.S. last year cost more than $24,000 .\nThe telecommuting program, like all GSA practices, is under review . 00027e965c8264c35cc1bc55556db388da82b07f
3 Los Angeles (CNN) -- A medical doctor in Vancouver, British Columbia, said Thursday that California arson suspect Harry Burkhart suffered from severe mental illness in 2010, when she examined him as part of a team of doctors. Dr. Blaga Stancheva, a family physician and specialist in obstetrics, said both Burkhart and his mother, Dorothee, were her patients in Vancouver while both were applying for refugee status in Canada. "I was asked to diagnose and treat Harry to support a claim explaining why he was unable to show up in a small-claims court case," Stancheva told CNN in a phone intervie... NEW: A Canadian doctor says she was part of a team examining Harry Burkhart in 2010 .\nNEW: Diagnosis: "autism, severe anxiety, post-traumatic stress disorder and depression"\nBurkhart is also suspected in a German arson probe, officials say .\nProsecutors believe the German national set a string of fires in Los Angeles . 0002c17436637c4fe1837c935c04de47adb18e9a
4 (CNN) -- Police arrested another teen Thursday, the sixth suspect jailed in connection with the gang rape of a 15-year-old girl on a northern California high school campus. Jose Carlos Montano, 18, was arrested on charges of felony rape, rape in concert with force, and penetration with a foreign object, said Richmond Police Lt. Mark Gagan. Montano was arrested Thursday evening in San Pablo, California, a small town about two miles from the city of Richmond, where the crime took place. Montano, who was held in lieu of $1.3 million bail, is accused of taking part in what police said was a 2½... Another arrest made in gang rape outside California school .\nInvestigators say up to 20 people took part or stood and watched the assault .\nFour suspects appeared in court Thursday; three wore bulletproof vests . 0003ad6ef0c37534f80b55b4235108024b407f0b

We begin by getting the huggingface objects needed for this task (i.e., the architecture, config, tokenizer, and model). We'll use blurr's get_hf_objects helper method here.

In [5]:
pretrained_model_name = "facebook/bart-large-cnn"
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR.get_hf_objects(pretrained_model_name, 
                                                                  model_cls=BartForConditionalGeneration)

hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)
Out[5]:
('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

Next we need to build out our DataBlock. Remember that a DataBlock is a blueprint describing how to move your raw data into something modelable. That blueprint is executed when we pass it a data source, which in our case will be the DataFrame we created above. Because that DataFrame only holds the first 1% of the training split, things also move along a bit faster for the demo.

Notice that the blurr DataBlock has been dramatically simplified by the shift to on-the-fly, batch-time tokenization. All we need is a single HF_Seq2SeqBeforeBatchTransform instance, optionally passing a list to any of the tokenization arguments to use different values for the input and summary sequences. In addition to setting a custom max length for the inputs (max_length), we can do the same for the output sequences (max_tgt_length), and with the latest release of blurr we can even customize text generation by passing in text_gen_kwargs.

We pass noop as a type transform for our targets because everything is already handled by the batch transform now.
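
To make batch-time padding concrete, here is a minimal, hypothetical sketch of the idea using plain Python lists: each sequence is truncated to a maximum length, and the batch is then padded only to its longest member. It illustrates the concept only; as I understand it, blurr hands the real work to the huggingface tokenizer rather than running anything like this. The token ids are made up, and the pad id of 1 is BART's pad_token_id from the text_gen_kwargs output below.

```python
from typing import List, Sequence


def pad_batch(seqs: Sequence[Sequence[int]], max_length: int, pad_id: int) -> List[List[int]]:
    """Truncate each token-id sequence to max_length, then right-pad to the batch's longest."""
    truncated = [list(s[:max_length]) for s in seqs]
    longest = max(len(s) for s in truncated)
    return [s + [pad_id] * (longest - len(s)) for s in truncated]


PAD_ID = 1  # BART's pad_token_id

# Made-up token ids: two long "articles" and two short "summaries"
articles = [list(range(2, 702)), list(range(2, 402))]  # 700 and 400 tokens
summaries = [list(range(2, 70)), list(range(2, 52))]   # 68 and 50 tokens

inputs = pad_batch(articles, max_length=256, pad_id=PAD_ID)
targets = pad_batch(summaries, max_length=130, pad_id=PAD_ID)
print(len(inputs), len(inputs[0]), len(targets), len(targets[0]))
```

With these made-up lengths the batch comes out as 2 x 256 inputs and 2 x 68 targets: long articles hit the max_length cap, while summaries are only padded as far as the longest one in the batch.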

In [6]:
text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model, task='summarization'); text_gen_kwargs
Out[6]:
{'bad_words_ids': None,
 'bos_token_id': 0,
 'decoder_start_token_id': 2,
 'diversity_penalty': 0.0,
 'do_sample': False,
 'early_stopping': True,
 'encoder_no_repeat_ngram_size': 0,
 'eos_token_id': 2,
 'forced_bos_token_id': 0,
 'forced_eos_token_id': 2,
 'length_penalty': 2.0,
 'max_length': 142,
 'min_length': 56,
 'no_repeat_ngram_size': 3,
 'num_beam_groups': 1,
 'num_beams': 4,
 'num_return_sequences': 1,
 'output_attentions': False,
 'output_hidden_states': False,
 'output_scores': False,
 'pad_token_id': 1,
 'remove_invalid_values': False,
 'repetition_penalty': 1.0,
 'return_dict_in_generate': False,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'use_cache': True}
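
One of these settings worth understanding is length_penalty. My understanding of huggingface's beam search (at least in the versions current when this was written) is that each finished hypothesis is scored as its summed log-probability divided by its length raised to length_penalty, so values above 1.0 favor longer sequences. The snippet below is a stand-alone sketch of that scoring rule with made-up numbers, not a call into transformers.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    """Length-normalized beam score: sum of token log-probs / length ** length_penalty."""
    return sum_logprobs / (length ** length_penalty)


# Two made-up finished hypotheses: (sum of log-probs, number of tokens)
short_hyp = (-5.0, 10)
long_hyp = (-8.0, 20)

for penalty in (0.0, 2.0):
    short_score = beam_score(*short_hyp, penalty)
    long_score = beam_score(*long_hyp, penalty)
    winner = "long" if long_score > short_score else "short"
    print(f"length_penalty={penalty}: short={short_score:.3f} long={long_score:.3f} -> {winner}")
```

With no length normalization the shorter hypothesis wins (-5.0 vs -8.0), while length_penalty=2.0 flips the ranking (-0.05 vs -0.02).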
In [7]:
hf_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model, 
                                              max_length=256, max_tgt_length=130, text_gen_kwargs=text_gen_kwargs)

blocks = (HF_Seq2SeqBlock(before_batch_tfm=hf_batch_tfm), noop)

dblock = DataBlock(blocks=blocks, get_x=ColReader('article'), get_y=ColReader('highlights'), splitter=RandomSplitter())
In [8]:
dls = dblock.dataloaders(df, bs=2)
In [9]:
len(dls.train.items), len(dls.valid.items)
Out[9]:
(2297, 574)
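
Those numbers come straight from RandomSplitter's defaults. As far as I know, fastai's RandomSplitter shuffles the item indices and sends the first int(valid_pct * n) of them (valid_pct defaults to 0.2) to the validation set. Here is a standard-library re-creation of that arithmetic. random_split is a hypothetical helper that uses Python's random module instead of torch.randperm, so the actual indices will differ from fastai's.

```python
import random
from typing import List, Optional, Tuple


def random_split(n_items: int, valid_pct: float = 0.2,
                 seed: Optional[int] = None) -> Tuple[List[int], List[int]]:
    """Shuffle item indices, then carve off the first valid_pct share as validation."""
    rng = random.Random(seed)
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)  # int() truncates: 0.2 * 2871 = 574.2 -> 574
    return idxs[cut:], idxs[:cut]


train_idxs, valid_idxs = random_split(2871, seed=42)
print(len(train_idxs), len(valid_idxs))
```

Because int() truncates, 20% of 2,871 items becomes 574 validation items, leaving 2,297 for training.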

It's always a good idea to check out a batch of data and make sure the shapes look right.

In [10]:
b = dls.one_batch()
len(b), b[0]['input_ids'].shape, b[1].shape
Out[10]:
(2, torch.Size([2, 256]), torch.Size([2, 68]))

Even better, we can take advantage of blurr's TypeDispatched version of show_batch to look at things a bit more intuitively. We pass in the dls via the dataloaders argument so we can access all tokenization/modeling configuration stored in our batch transform above.

In [11]:
dls.show_batch(dataloaders=dls, max_n=2)
text target
0 The news from Pakistan is generally bad news. In the past week, which was far from atypical, suicide bombers attacked a court building in the northwestern city of Peshawar taking hostages and killing four people. In the southern city of Karachi the director of a renowned social program working in the megacity's poorest neighborhoods was shot and killed. And gunmen kidnapped two female Czech tourists in southwestern Pakistan. But this past week also saw more than a glimmer of good news from Pakistan: Saturday, March 16 marked an extraordinary moment in Pakistani history, as this is the first time a civilian government has served its entire five-year term (from 2008 to 2013). And, for the first time in its history, the Pakistani military appears unwilling to mount a coup against the civilian government. The military has successfully executed three coups and attempted a number of others since Pakistan's independence in 1947. Today the army understands that the most recent coup by General Pervez Musharraf who took power in 1999 has tarnished its brand. Musharraf hung on to power for almost a decade and his imposition of emergency rule in 2007 triggered massive street protests and eventually his ouster. On Saturday, Musharaf announced he is returning to Pakistan from self-imposed exile on March 24 to Peter Bergen: For the first time, Pakistan government served its full term.\nHe says lack of military coup attempt shows government is more stable than many think.\nElections in Pakistan, Afghanistan likely to be crucial for those two nations.\nBergen: He says Afghan economy is resilient and corruption may be receding.
1 (CNN) -- A controversial Colombian senator who has obtained the release of 16 hostages held by Marxist guerrillas is the leading candidate to receive this year's Nobel Peace Prize, which will be announced Friday, said an independent research institute in Norway. Sen. Piedad Cordoba, right, of Colombia reportedly is one of three top contenders for the Nobel Peace Prize. Sen. Piedad Cordoba is the most likely recipient among three leading contenders, said the Oslo-based International Peace Research Institute. The others the institute named are Jordanian Prince Ghazi bin Muhammad, a philosophy professor in Islamic faith at Jordan University, and Afghan physician and human rights activist Sima Samar. Though the institute considers Cordoba the front-runner, no single candidate has emerged as the clear-cut favorite, as sometimes happens, said Kristian Berg Harpviken, director of the peace institute. "It really is quite open this year," Harpviken said. This year's peace prize nominees include 172 people and 33 organizations. The committee does not release the names of the nominees. The 50-year-old peace institute, which is often called PRIO, has no connection with the Nobel committee that awards the peace prize. Harpviken said he believes the Independent research institute cites three top contenders for Nobel Peace Prize.\nNo candidate emerges as clear-cut favorite; winner to be announced Friday.\nColombian senator, Jordanian prince, Afghan rights activist among contenders.\nVietnamese Buddhist monk, Chinese dissidents also could be awarded prize.

Training

We'll prepare our BART model for training by wrapping it in blurr's HF_BaseModelWrapper object and using the callback, HF_BaseModelCallback, as usual. A new HF_Seq2SeqMetricsCallback object allows us to specify the Seq2Seq metrics we want to use: things like rouge and bertscore for tasks like summarization, as well as metrics such as meteor, bleu, and sacrebleu for translation tasks. Using huggingface's metrics library is as easy as specifying a metrics configuration like the one below.

Once we have everything in place, we'll freeze our model so that only the last layer group's parameters are trainable. See the fastai docs for how discriminative learning rates work in fastai.

Note: This has been tested with a lot of other Seq2Seq models; see the docs for more information.

In [12]:
seq2seq_metrics = {
        'rouge': {
            'compute_kwargs': { 'rouge_types': ["rouge1", "rouge2", "rougeL"], 'use_stemmer': True },
            'returns': ["rouge1", "rouge2", "rougeL"]
        },
        'bertscore': {
            'compute_kwargs': { 'lang': 'en' },
            'returns': ["precision", "recall", "f1"]
        }
    }
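
For intuition about what the rouge1 column reports, here is a deliberately simplified ROUGE-1 F1: count the unigrams the prediction and reference share, then take the harmonic mean of precision and recall. It lowercases and splits on non-alphanumeric characters but skips the stemming that use_stemmer=True enables, so treat it as an approximation of the metric, not a drop-in replacement for the rouge implementation that huggingface's metrics library wraps.

```python
import re
from collections import Counter


def _unigrams(text: str) -> Counter:
    """Lowercase and keep alphanumeric runs as tokens (no stemming)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def rouge1_f1(prediction: str, reference: str) -> float:
    """Simplified ROUGE-1 F1 based on clipped unigram overlap."""
    pred, ref = _unigrams(prediction), _unigrams(reference)
    overlap = sum((pred & ref).values())  # clipped count of shared unigrams
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


print(round(rouge1_f1("The cat sat on the mat.", "The cat lay on the mat."), 4))
```

Here 5 of the 6 unigrams match on each side, so precision, recall, and F1 are all 5/6 (about 0.8333).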
In [13]:
model = HF_BaseModelWrapper(hf_model)
learn_cbs = [HF_BaseModelCallback]
fit_cbs = [HF_Seq2SeqMetricsCallback(custom_metrics=seq2seq_metrics)]

learn = Learner(dls, 
                model,
                opt_func=ranger,
                loss_func=CrossEntropyLossFlat(),
                cbs=learn_cbs,
                splitter=partial(seq2seq_splitter, arch=hf_arch)).to_fp16()

learn.create_opt() 
learn.freeze()
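
learn.freeze() leaves only the last parameter group trainable. If you later unfreeze and pass a slice such as slice(1e-5, 1e-3) as the learning rate, fastai (to my understanding of how its optimizer handles slices) spreads the rates geometrically across the parameter groups, so earlier layers train more gently than the head. even_mults below is a plain-Python re-creation of that spacing. The three groups are an assumption for illustration, since the number of groups seq2seq_splitter creates for BART isn't shown here.

```python
from typing import List


def even_mults(start: float, stop: float, n: int) -> List[float]:
    """Return n values from start to stop, evenly spaced on a log scale (a geometric series)."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


# Hypothetical: three parameter groups, lr=slice(1e-5, 1e-3)
lrs = even_mults(1e-5, 1e-3, 3)
print([f"{lr:.0e}" for lr in lrs])
```

Each group's rate is 10x the previous one here, because the 100x range between 1e-5 and 1e-3 is split across two geometric steps.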


I'm still experimenting with how to use fastai's learning rate finder for these kinds of models, so if you have any suggestions or interesting insights to share, please let me know. We're only going to train the frozen model for one epoch for this demo, but feel free to progressively unfreeze the model and train the other layers to see if you can beat my results below.

In [14]:
learn.lr_find(suggestions=True)
Out[14]:
SuggestedLRs(lr_min=0.09120108485221863, lr_steep=0.7585775852203369)

It's also not a bad idea to run a batch through your model and make sure the shape of what goes in, and comes out, looks right.

In [15]:
b = dls.one_batch()
preds = learn.model(b[0])
len(preds),preds[0], preds[1].shape
Out[15]:
(4,
 tensor(3.8408, device='cuda:0', grad_fn=<NllLossBackward>),
 torch.Size([2, 68, 50264]))
In [16]:
learn.fit_one_cycle(1, lr_max=3e-5, cbs=fit_cbs)
epoch train_loss valid_loss rouge1 rouge2 rougeL bertscore_precision bertscore_recall bertscore_f1 time
0 1.697474 1.692343 0.375504 0.157680 0.253458 0.875942 0.890704 0.883181 12:32
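
fit_one_cycle doesn't hold the learning rate at lr_max=3e-5 for the whole epoch. Assuming fastai's defaults (pct_start=0.25, div=25, div_final=1e5), the rate warms up from lr_max/div to lr_max along a cosine curve over the first 25% of training, then anneals along another cosine down to lr_max/div_final. The sketch below recomputes that schedule in plain Python; it's a reconstruction of the idea, not a call into fastai.

```python
import math


def sched_cos(start: float, end: float, pos: float) -> float:
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pos: float, lr_max: float = 3e-5, div: float = 25.0,
                 div_final: float = 1e5, pct_start: float = 0.25) -> float:
    """Learning rate at training progress pos in [0, 1] under a one-cycle policy."""
    if pos < pct_start:
        return sched_cos(lr_max / div, lr_max, pos / pct_start)
    return sched_cos(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


for pos in (0.0, 0.25, 0.5, 1.0):
    print(f"{pos:.2f}: {one_cycle_lr(pos):.2e}")
```

Under these defaults the single epoch starts at 1.2e-6, peaks at 3e-5 a quarter of the way through, and finishes near 3e-10.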

And now we can look at the generated predictions, which use the text_gen_kwargs we defined above.

In [17]:
learn.show_results(learner=learn, max_n=2)
text target prediction
0 (CNN) -- Two weeks. Two gut-wrenching, frustrating, mysterious weeks. That's how long it's been since 227 passengers and 12 crew members boarded Malaysia Airlines Flight 370, destined for Beijing. A routine trip, it seemed, to catch up relatives in time for the weekend, start on a work assignment or just get away. Where they got to, still unknown. An exhaustive search -- covering a mind-boggling 2.97 million square miles, which is nearly the size of the continental United States -- has yielded some clues, but no proof of where the Boeing 777 is or definitively what happened to it. The latest, most notable lead revolved around two large objects detected by satellite Sunday floating on waters over 1,400 miles off of Australia's west coast. The first of several Australian military planes, as well as two long-range commercial jets, resumed their search Saturday morning to find any trace of the objects, amid some skepticism that they or ships in the area ever will and, if they do, that whatever they find will be related to the missing aircraft. Australian Prime Minister Tony Abbott on Friday defended the decision to announce the find, saying Australia owes it to families of those missing "to give them information as soon as it's NEW: Planes depart Australia to resume their search for airplane debris.\nNEW: Official: Passengers' relatives are moved to a different Kuala Lumpur hotel.\nObjects seen on satellite spark intensive search in southern Indian Ocean.\nU.S. officials: Files were deleted from flight simulator's hard drive after February 3. NEW: Australian military planes resume their search Saturday morning .\nThe search area is nearly the size of the continental United States .\nIt's been more than two weeks since Malaysia Airlines Flight 370 disappeared .\nA satellite detected two objects floating in waters over 1,400 miles off Australia's west coast .\nAustralia's prime minister defends the decision to announce the find .
1 U.N. weapons inspectors returned "overwhelming and indisputable" evidence of the use of nerve gas in Syria, Secretary-General Ban Ki-moon said Monday, calling the findings "beyond doubt and beyond the pale." The inspectors' 38-page report was released after Ban briefed Security Council members on its contents. The team found what it called "clear and convincing evidence" that the nerve agent sarin was delivered by surface-to-surface rockets "on a relatively large scale" in the suburbs of the Syrian capital Damascus on August 21. "It is the most significant confirmed use of chemical weapons against civilians since Saddam Hussein used them in Halabja in 1988, and the worst use of weapons of mass destruction in the 21st century," Ban said. "The international community has a responsibility to ensure that chemical weapons never re-emerge as an instrument of warfare," he said. Ban called the attack "a war crime" and a violation of treaties banning the use of chemical weapons that date back to 1925. But the inspectors' mandate did not include assigning blame for the attack, and Ban would not speculate on who launched the attack. The team did identify two types or rockets it said were used to deliver the gas and their trajectories, and international Syria findings "beyond doubt and beyond the pale," Ban says.\nU.S. to provide chemical protective gear to opposition, inspectors.\nSarin report demands "a unified and decisive response," Syrian opposition says.\nSyria says helicopter was shot down after straying into Turkish airspace. Ban Ki-moon calls the findings "beyond doubt and beyond the pale"\nInspectors' 38-page report released after Ban briefed Security Council members on its contents .\n"The international community has a responsibility to ensure that chemical weapons never re-emerge as an instrument of warfare," Ban says .\nThe inspectors' mandate did not include assigning blame for the attack, and Ban would not speculate on who launched the attack .

Even better, blurr augments the fastai Learner with a blurr_generate method that uses huggingface's PreTrainedModel.generate method to create something more human-like.

In [18]:
test_article = """
The past 12 months have been the worst for aviation fatalities so far this decade - with the total of number of people killed if airline 
crashes reaching 1,050 even before the Air Asia plane vanished. Two incidents involving Malaysia Airlines planes - one over eastern Ukraine and the other in the Indian Ocean - led to the deaths of 537 people, while an Air Algerie crash in Mali killed 116 and TransAsia Airways crash in Taiwan killed a further 49 people. The remaining 456 fatalities were largely in incidents involving small commercial planes or private aircraft operating on behalf of companies, governments or organisations. Despite 2014 having the highest number of fatalities so far this decade, the total number of crashes was in fact the lowest since the first commercial jet airliner took off in 1949 - totalling just 111 across the whole world over the past 12 months. The all-time deadliest year for aviation was 1972 when a staggering 2,429 people were killed in a total of 55 plane crashes - including the crash of Aeroflot Flight 217, which killed 174 people in Russia, and Convair 990 Coronado, which claimed 155 lives in Spain. However this year's total death count of 1,212, including those presumed dead on board the missing Air Asia flight, marks a significant rise on the very low 265 fatalities in 2013 - which led to it being named the safest year in aviation since the end of the Second World War. Scroll down for videos. Deadly: The past 12 months have been the worst for aviation fatalities so far this decade - with the total of number of people killed if airline crashes reaching 1,158 even before the Air Asia plane (pictured) vanished. Fatal: Two incidents involving Malaysia Airlines planes - one over eastern Ukraine (pictured) and the other in the Indian Ocean - led to the deaths of 537 people. Surprising: Despite 2014 having the highest number of fatalities so far this decade, the total number of crashes was in fact the lowest since the first commercial jet airliner took off in 1949. 
2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident. In total more than half the people killed in aviation incidents this year had been flying on board Malaysia-registered planes. In January a total of 12 people lost their lives in five separate incidents, while the same number of crashes in February killed 107. 
"""

We can override the text_gen_kwargs we specified for our DataLoaders when generating text with blurr's Learner.blurr_generate method.

In [19]:
outputs = learn.blurr_generate(test_article, early_stopping=True, num_beams=4, num_return_sequences=3)

for idx, o in enumerate(outputs):
    print(f'=== Prediction {idx+1} ===\n{o}\n')
=== Prediction 1 ===
 The past 12 months have been the worst for aviation fatalities so far this decade .
The total number of people killed if airline crashes reached 1,158 even before Air Asia plane vanished .
2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident .
More than half the people killed in aviation incidents this year had been flying on Malaysia-registered planes .

=== Prediction 2 ===
 The past 12 months have been the worst for aviation fatalities so far this decade .
The total number of people killed if airline crashes reached 1,158 even before Air Asia plane vanished .
2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident .
This year's total death count of 1,212 marks a significant rise on the very low 265 fatalities in 2013 .

=== Prediction 3 ===
 The past 12 months have been the worst for aviation fatalities so far this decade .
The total number of people killed if airline crashes reached 1,158 even before Air Asia plane vanished .
2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident .
More than half the people killed in aviation incidents this year had been flying on board Malaysia-registered planes .
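
One way to picture that override is as a dictionary merge in which the keyword arguments passed to blurr_generate win over the defaults in text_gen_kwargs. That's an assumption about the behavior described above, not blurr's code, and merge_gen_kwargs is a hypothetical helper. The sketch also checks a constraint huggingface's generate enforces when sampling is off: num_return_sequences can't exceed num_beams, which is why asking for 3 sequences works here with num_beams=4.

```python
from typing import Any, Dict

# A few of the summarization defaults reported by default_text_gen_kwargs above
summarization_defaults: Dict[str, Any] = {
    "num_beams": 4, "num_return_sequences": 1, "max_length": 142,
    "min_length": 56, "early_stopping": True, "do_sample": False,
}


def merge_gen_kwargs(defaults: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Hypothetical helper: per-call overrides take precedence over the defaults."""
    merged = {**defaults, **overrides}
    if not merged["do_sample"] and merged["num_return_sequences"] > merged["num_beams"]:
        raise ValueError("num_return_sequences must be <= num_beams when not sampling")
    return merged


gen_kwargs = merge_gen_kwargs(summarization_defaults, early_stopping=True,
                              num_beams=4, num_return_sequences=3)
print(gen_kwargs["num_return_sequences"], gen_kwargs["max_length"])
```

The call above ends up with 3 return sequences while keeping the default max_length of 142; the defaults dictionary itself is left unchanged.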

What about inference? Easy!

In [20]:
learn.metrics = None
learn.export(fname='ft_cnndm_export.pkl')
In [21]:
inf_learn = load_learner(fname='ft_cnndm_export.pkl')
inf_learn.blurr_generate(test_article)
Out[21]:
[" The past 12 months have been the worst for aviation fatalities so far this decade .\nThe total number of people killed if airline crashes reached 1,158 even before Air Asia plane vanished .\n2014 has been a horrific year for Malaysia-based airlines, with 537 people dying on Malaysia Airlines planes, and a further 162 people missing and feared dead in this week's Air Asia incident .\nMore than half the people killed in aviation incidents this year had been flying on Malaysia-registered planes ."]

That's it

blurr supports a number of huggingface transformer model tasks in addition to summarization (e.g., sequence classification, token classification, question answering, causal language modeling, and translation). The docs include examples for each of these tasks if you're curious to learn more.

For more information about ohmeow or to get in contact with me, head over to ohmeow.com for all the details.

Thanks!
