Preparing data for LLM training#

We are going to train our LLM using the PubMed dataset, which contains abstracts from biomedical journal articles. To keep things quick for the workshop, we will work with a small subset of 50k abstracts. There are two key concepts to consider when preparing data for LLM training:

  1. Tokenization: how do we turn the raw text strings into units of analysis for our program?

  2. Batching: how do we combine multiple documents into a single batch data structure for efficient model training?

To address these issues, we will use a pre-built biomedical tokenizer available from the Hugging Face Hub. Implementing tokenizers is a complicated topic in its own right, and we will not cover it in detail here.

One pleasant difference between LLMs and the previous generation of NLP techniques is that we do not usually need to perform elaborate data preprocessing to achieve good results.

PubMed Data#

The Hugging Face datasets library contains some useful utilities for loading and working with text data, which we use here.

from datasets import load_dataset
# path to folder containing "train.txt" and "test.txt" files containing train/test PubMed abstracts
root = "/project/rcde/datasets/pubmed/mesh_50k/splits/"

train_test_files = {
    "train": root+"train.txt",
    "test": root+"test.txt"
}

dataset = load_dataset("text", data_files = train_test_files).with_format("torch")

dataset
Found cached dataset text (/home/dane2/.cache/huggingface/datasets/text/default-cadbbf8acc2e2b5a/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 45137
    })
    test: Dataset({
        features: ['text'],
        num_rows: 5036
    })
})

Let’s check the sizes of the training and test sets:

len(dataset["train"]), len(dataset["test"])
(45137, 5036)

Let’s look at a particular training sample:

dataset["train"][34799]
{'text': '"BACKGROUND & AIMS: The contribution of duodeno-gastroesophageal reflux to the development of Barrett\'s esophagus has remained an interesting but controversial topic. The present study assessed the risk for Barrett\'s esophagus after partial gastrectomy.METHODS: The data of outpatients from a medicine and gastroenterology clinic who underwent upper gastrointestinal endoscopy for any reason were analyzed in a case-control study. A case population of 650 patients with short- segment and 366 patients with long-segment Barrett\'s esophagus was compared in a multivariate logistic regression to a control population of 3047 subjects without Barrett\'s esophagus or other types of gastroesophageal reflux disease.RESULTS: In the case population, 25 (4%) patients with short-segment and 15 (4%) patients with long-segment Barrett\'s esophagus presented with a history of gastric surgery compared with 162 (5%) patients in the control population, yielding an adjusted odds ratio of 0.89 with a 95% confidence interval of 0.54-1.46 for short-segment and an adjusted odds ratio of 0.71 (0.30-1.72) for long-segment Barrett\'s esophagus. Similar results were obtained in separate analyses of 64 patients with Billroth-1 gastrectomy, 105 patients with Billroth-2 gastrectomy, and 33 patients with vagotomy and pyloroplasty for both short- and long-segment Barrett\'s esophagus. Caucasian ethnicity, the presence of hiatus hernia, and alcohol consumption were all associated with elevated risks for Barrett\'s esophagus.CONCLUSIONS: Gastric surgery for benign peptic ulcer disease is not a risk factor for either short- or long-segment Barrett\'s esophagus. This lack of association between gastric surgery and Barrett\'s esophagus suggests that reflux of bile without acid is not sufficient to damage the esophageal mucosa."'}

Tokenization#

Here, we will use the Hugging Face transformers library to fetch a tokenizer purpose-built for biomedical data. The AutoTokenizer class allows us to provide the name of a model on the Hugging Face Hub and automatically retrieve the associated tokenizer. We could experiment with different tokenizers to try to achieve better results.

from transformers import AutoTokenizer
# use a pretrained tokenizer
# https://huggingface.co/dmis-lab/biobert-base-cased-v1.2
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.2")

Let’s tokenize some text:

tokenizer(dataset["train"][34799]['text'])
{'input_ids': [101, 107, 3582, 111, 8469, 131, 1103, 6436, 1104, 6862, 2883, 1186, 118, 3245, 8005, 1279, 4184, 19911, 1348, 1231, 2087, 24796, 1106, 1103, 1718, 1104, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 1144, 1915, 1126, 5426, 1133, 6241, 8366, 119, 1103, 1675, 2025, 14758, 1103, 3187, 1111, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 1170, 7597, 3245, 7877, 5822, 18574, 119, 4069, 131, 1103, 2233, 1104, 1149, 27420, 1116, 1121, 170, 5182, 1105, 3245, 8005, 25195, 4807, 12257, 1150, 9315, 3105, 3245, 8005, 10879, 2556, 14196, 1322, 2155, 20739, 1111, 1251, 2255, 1127, 17689, 1107, 170, 1692, 118, 1654, 2025, 119, 170, 1692, 1416, 1104, 14166, 4420, 1114, 1603, 118, 6441, 1105, 3164, 1545, 4420, 1114, 1263, 118, 6441, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 1108, 3402, 1107, 170, 4321, 8997, 15045, 9366, 5562, 1231, 24032, 1106, 170, 1654, 1416, 1104, 26714, 1559, 5174, 1443, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 1137, 1168, 3322, 1104, 3245, 8005, 1279, 4184, 19911, 1348, 1231, 2087, 24796, 3653, 119, 2686, 131, 1107, 1103, 1692, 1416, 117, 1512, 113, 125, 110, 114, 4420, 1114, 1603, 118, 6441, 1105, 1405, 113, 125, 110, 114, 4420, 1114, 1263, 118, 6441, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 2756, 1114, 170, 1607, 1104, 3245, 11048, 6059, 3402, 1114, 19163, 113, 126, 110, 114, 4420, 1107, 1103, 1654, 1416, 117, 23731, 1126, 10491, 10653, 6022, 1104, 121, 119, 5840, 1114, 170, 4573, 110, 6595, 14235, 1104, 121, 119, 4335, 118, 122, 119, 3993, 1111, 1603, 118, 6441, 1105, 1126, 10491, 10653, 6022, 1104, 121, 119, 5729, 113, 121, 119, 1476, 118, 122, 119, 5117, 114, 1111, 1263, 118, 6441, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 119, 1861, 2686, 1127, 3836, 1107, 2767, 18460, 1104, 3324, 4420, 1114, 4550, 21941, 118, 122, 3245, 7877, 5822, 18574, 117, 8359, 4420, 1114, 4550, 21941, 118, 123, 3245, 7877, 5822, 18574, 117, 1105, 3081, 4420, 1114, 191, 19308, 18778, 1183, 1105, 185, 7777, 14824, 1643, 19268, 1183, 1111, 1241, 1603, 118, 1105, 1263, 118, 6441, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 119, 11019, 23315, 11890, 21052, 117, 1103, 2915, 1104, 14938, 1123, 5813, 117, 1105, 6272, 8160, 1127, 1155, 2628, 1114, 8208, 11040, 1111, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 119, 16421, 131, 3245, 11048, 6059, 1111, 26181, 11368, 185, 15384, 1596, 23449, 14840, 3653, 1110, 1136, 170, 3187, 5318, 1111, 1719, 1603, 118, 1137, 1263, 118, 6441, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 119, 1142, 2960, 1104, 3852, 1206, 3245, 11048, 6059, 1105, 2927, 8127, 1204, 112, 188, 13936, 4184, 2328, 12909, 5401, 1115, 1231, 2087, 24796, 1104, 16516, 1513, 1443, 5190, 1110, 1136, 6664, 1106, 3290, 1103, 13936, 4184, 19911, 1348, 182, 21977, 9275, 119, 107, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

The input_ids list contains an encoded representation of our text: a sequence of integer IDs corresponding to the tokens that appear in the text we tokenized. Each ID refers to a specific entry in a pre-defined vocabulary that ships with the tokenizer, so the input_ids list can be decoded back into our original text.
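To make this mapping concrete, here is a small illustrative sketch (reusing the tokenizer loaded above) that looks up the subword tokens behind the integer IDs for a short, made-up string and then decodes them back:

# inspect the id <-> token mapping on a short example string
encoded = tokenizer("Barrett's esophagus")
print(encoded['input_ids'])                                   # integer IDs, including special tokens
print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))  # the subword tokens those IDs refer to
print(tokenizer.decode(encoded['input_ids']))                 # reassembled text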

We will usually want PyTorch tensors, not lists, as output. For a batch of sequences with different lengths, this requires enabling padding.

What do you think padding does?

ids = tokenizer(dataset["train"][34799:34801]['text'], return_tensors='pt', padding=True)['input_ids']
ids
tensor([[  101,   107,  3582,   111,  8469,   131,  1103,  6436,  1104,  6862,
          2883,  1186,   118,  3245,  8005,  1279,  4184, 19911,  1348,  1231,
          2087, 24796,  1106,  1103,  1718,  1104,  2927,  8127,  1204,   112,
           188, 13936,  4184,  2328, 12909,  1144,  1915,  1126,  5426,  1133,
          6241,  8366,   119,  1103,  1675,  2025, 14758,  1103,  3187,  1111,
          2927,  8127,  1204,   112,   188, 13936,  4184,  2328, 12909,  1170,
          7597,  3245,  7877,  5822, 18574,   119,  4069,   131,  1103,  2233,
          1104,  1149, 27420,  1116,  1121,   170,  5182,  1105,  3245,  8005,
         25195,  4807, 12257,  1150,  9315,  3105,  3245,  8005, 10879,  2556,
         14196,  1322,  2155, 20739,  1111,  1251,  2255,  1127, 17689,  1107,
           170,  1692,   118,  1654,  2025,   119,   170,  1692,  1416,  1104,
         14166,  4420,  1114,  1603,   118,  6441,  1105,  3164,  1545,  4420,
          1114,  1263,   118,  6441,  2927,  8127,  1204,   112,   188, 13936,
          4184,  2328, 12909,  1108,  3402,  1107,   170,  4321,  8997, 15045,
          9366,  5562,  1231, 24032,  1106,   170,  1654,  1416,  1104, 26714,
          1559,  5174,  1443,  2927,  8127,  1204,   112,   188, 13936,  4184,
          2328, 12909,  1137,  1168,  3322,  1104,  3245,  8005,  1279,  4184,
         19911,  1348,  1231,  2087, 24796,  3653,   119,  2686,   131,  1107,
          1103,  1692,  1416,   117,  1512,   113,   125,   110,   114,  4420,
          1114,  1603,   118,  6441,  1105,  1405,   113,   125,   110,   114,
          4420,  1114,  1263,   118,  6441,  2927,  8127,  1204,   112,   188,
         13936,  4184,  2328, 12909,  2756,  1114,   170,  1607,  1104,  3245,
         11048,  6059,  3402,  1114, 19163,   113,   126,   110,   114,  4420,
          1107,  1103,  1654,  1416,   117, 23731,  1126, 10491, 10653,  6022,
          1104,   121,   119,  5840,  1114,   170,  4573,   110,  6595, 14235,
          1104,   121,   119,  4335,   118,   122,   119,  3993,  1111,  1603,
           118,  6441,  1105,  1126, 10491, 10653,  6022,  1104,   121,   119,
          5729,   113,   121,   119,  1476,   118,   122,   119,  5117,   114,
          1111,  1263,   118,  6441,  2927,  8127,  1204,   112,   188, 13936,
          4184,  2328, 12909,   119,  1861,  2686,  1127,  3836,  1107,  2767,
         18460,  1104,  3324,  4420,  1114,  4550, 21941,   118,   122,  3245,
          7877,  5822, 18574,   117,  8359,  4420,  1114,  4550, 21941,   118,
           123,  3245,  7877,  5822, 18574,   117,  1105,  3081,  4420,  1114,
           191, 19308, 18778,  1183,  1105,   185,  7777, 14824,  1643, 19268,
          1183,  1111,  1241,  1603,   118,  1105,  1263,   118,  6441,  2927,
          8127,  1204,   112,   188, 13936,  4184,  2328, 12909,   119, 11019,
         23315, 11890, 21052,   117,  1103,  2915,  1104, 14938,  1123,  5813,
           117,  1105,  6272,  8160,  1127,  1155,  2628,  1114,  8208, 11040,
          1111,  2927,  8127,  1204,   112,   188, 13936,  4184,  2328, 12909,
           119, 16421,   131,  3245, 11048,  6059,  1111, 26181, 11368,   185,
         15384,  1596, 23449, 14840,  3653,  1110,  1136,   170,  3187,  5318,
          1111,  1719,  1603,   118,  1137,  1263,   118,  6441,  2927,  8127,
          1204,   112,   188, 13936,  4184,  2328, 12909,   119,  1142,  2960,
          1104,  3852,  1206,  3245, 11048,  6059,  1105,  2927,  8127,  1204,
           112,   188, 13936,  4184,  2328, 12909,  5401,  1115,  1231,  2087,
         24796,  1104, 16516,  1513,  1443,  5190,  1110,  1136,  6664,  1106,
          3290,  1103, 13936,  4184, 19911,  1348,   182, 21977,  9275,   119,
           107,   102],
        [  101,   107,  2693,  1103,  4495,  1104,  1103,  9323,  1112,   170,
          2235,  1449,   117,  1412,  4287,  1104,  1103,  1718,  1104,  9323,
           185, 10205,  6944,  2916,   176,  1200,  1306,  3652,   113,   185,
          1403,  6063,   114,  1110,  1677,  1121,  2335,   119,  1303,  1195,
          6858,  1103, 22740,  1104,   185,  1403,  6063,  1120,  1472, 16700,
          5251,   117,  1147, 10348,  4844,  1107,  1103, 13243,  1143, 27408,
          5075,  1104,  1103,  9323,  9712,  1830, 26503,   117,  1105,  1103,
          3735,  1104,  1103,  9712,  1161,  1475,   174, 18965, 15622,  1113,
           185,  1403,  6063,   119,  1103, 15442,  3735,  1104,   185,  1403,
          6063,  1219,  1147, 10348,  1108,  6858,  1118, 13280, 13601,  2728,
          2087,  7535, 12238, 14797,  1113,  2006,   118,  5378,  9323,  9712,
          1830, 26503,  1116,  1105,  1113, 18311, 16274,  4886,   117,  1606,
          9712,  1161,  1475,  1105,  9323,   191, 18384, 16358,  3702, 13791,
         26491,   119,  1229,  1107,  1103,   176,  1200, 14503,  1233,   172,
          4894,  8298,   185,  1403,  6063,  1127,  7982,  1105,  1178,  1512,
           110,  1104,  1172,  1127, 12893,  1118,  9712,  1161,  1475,   117,
          1510,  1562,  1112,   170,  7902, 10005,  1113,  1103,  2765,  2473,
           117,  1378,  3908, 11509,  1891,  1105, 10348,  1107,  1103, 13243,
          1143, 27408,  5075,   185,  1403,  6063,  2888,  1126, 19564, 22740,
           117,  1105,  3078,   110,  7799,  9712,  1161,  1475,   174, 18965,
         15622,   117,  1134,  1108,  7902,  1120,  1103,  5580,  1104,  1103,
         23563,  5674,  7168,   117,  1120,  1103,  3232,  3911,  1206,  8480,
           185,  1403,  6063,   119,  8179,  1104,   185,  1403,  1665, 10348,
          1107,  1103, 13243,  1143, 27408,  5075,  1104,  5871, 12913, 23872,
          1105,  5871,  3080, 13464,  2016,  1406,   118,  1659,  9712,  1830,
         26503,  1116,  7160,   170,  1286,   118,  1268,  1112, 17162, 21027,
           117,  1112, 10348,  1104,  3652,  1755,  1103,   176, 21462,  6163,
         18431,  1108,  1932,  7458,  1106,  1103,  1268,   117,  1897,  1190,
          1103,  1286,   117,  1334,  1104,  1103,  1143, 27408,  5075,   119,
          1167,  5909,   117,  1126,  8179,  1104,  1330,  1372,  1104,  3652,
          1115, 25572,  1194,  1103, 13243,  1143, 27408,  5075,   117,  1103,
          3873,  1596, 18250, 13468,  3652,   117,  3090,   170,  1861, 12629,
          1111,  1103,  1268,  1334,  1104,  1103,  1143, 27408,  5075,   117,
          8783,  1115,  1103,  1940, 24521, 13548,  1104,   185,  1403,  6063,
          1110, 26754,  1118,  1103,  1143, 27408,  5075,  2111,   119,  1412,
          9505,  2194,  1207, 24180,  1154,  1103, 10348, 13548,  1104,   185,
          1403,  6063,  1107,  1103, 13243,  1143, 27408,  5075,   117,  1105,
          5996,   170,  5088,  1206,  9712,  1161,  1475,   117,   185,  1403,
          1665, 10348,  1105,  2765,   118,  2765, 10393,   119,  1292,  9505,
          1336,  8681,  1106,   170,  1618,  4287,  1104,  1103,  6978, 10311,
         10348,  1104,   185,  1403,  6063,  1107,   170, 10644,  1116,   119,
           107,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]])

Can you see how padding appears in the tokenized text?
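If the padding is hard to spot in the large tensor above, this small sketch with two short, made-up sentences shows the effect more clearly: the shorter sequence is extended with the pad token ID (0) so that both rows fit in one rectangular tensor.

# a tiny padding example: two sequences of different lengths
toy = tokenizer(
    ["Reflux was assessed.", "A case-control study of Barrett's esophagus was performed."],
    return_tensors='pt',
    padding=True
)
print(toy['input_ids'])       # the shorter row ends in pad token IDs (0)
print(toy['attention_mask'])  # 1 = real token, 0 = padding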

Do you notice anything special about the first and last non-padding tokens?
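As a hint, the tokenizer exposes its special tokens directly; this short sketch prints the tokens it adds at the start and end of every sequence, along with the token it uses for padding:

# special tokens used by this BERT-style tokenizer
print(tokenizer.cls_token, tokenizer.cls_token_id)  # prepended to every sequence
print(tokenizer.sep_token, tokenizer.sep_token_id)  # appended to every sequence
print(tokenizer.pad_token, tokenizer.pad_token_id)  # used to pad shorter sequences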

Let’s decode back:

[tokenizer.decode(input_ids) for input_ids in ids]
['[CLS] " background & aims : the contribution of duodeno - gastroesophageal reflux to the development of barrett\'s esophagus has remained an interesting but controversial topic. the present study assessed the risk for barrett\'s esophagus after partial gastrectomy. methods : the data of outpatients from a medicine and gastroenterology clinic who underwent upper gastrointestinal endoscopy for any reason were analyzed in a case - control study. a case population of 650 patients with short - segment and 366 patients with long - segment barrett\'s esophagus was compared in a multivariate logistic regression to a control population of 3047 subjects without barrett\'s esophagus or other types of gastroesophageal reflux disease. results : in the case population, 25 ( 4 % ) patients with short - segment and 15 ( 4 % ) patients with long - segment barrett\'s esophagus presented with a history of gastric surgery compared with 162 ( 5 % ) patients in the control population, yielding an adjusted odds ratio of 0. 89 with a 95 % confidence interval of 0. 54 - 1. 46 for short - segment and an adjusted odds ratio of 0. 71 ( 0. 30 - 1. 72 ) for long - segment barrett\'s esophagus. similar results were obtained in separate analyses of 64 patients with billroth - 1 gastrectomy, 105 patients with billroth - 2 gastrectomy, and 33 patients with vagotomy and pyloroplasty for both short - and long - segment barrett\'s esophagus. caucasian ethnicity, the presence of hiatus hernia, and alcohol consumption were all associated with elevated risks for barrett\'s esophagus. conclusions : gastric surgery for benign peptic ulcer disease is not a risk factor for either short - or long - segment barrett\'s esophagus. this lack of association between gastric surgery and barrett\'s esophagus suggests that reflux of bile without acid is not sufficient to damage the esophageal mucosa. " [SEP]',
 '[CLS] " despite the importance of the chicken as a model system, our understanding of the development of chicken primordial germ cells ( pgcs ) is far from complete. here we characterized the morphology of pgcs at different developmental stages, their migration pattern in the dorsal mesentery of the chicken embryo, and the distribution of the ema1 epitope on pgcs. the spatial distribution of pgcs during their migration was characterized by immunofluorescence on whole - mounted chicken embryos and on paraffin sections, using ema1 and chicken vasa homolog antibodies. while in the germinal crescent pgcs were rounded and only 25 % of them were labeled by ema1, often seen as a concentrated cluster on the cell surface, following extravasation and migration in the dorsal mesentery pgcs acquired an elongated morphology, and 90 % exhibited ema1 epitope, which was concentrated at the tip of the pseudopodia, at the contact sites between neighboring pgcs. examination of pgc migration in the dorsal mesentery of hamburger and hamilton stage 20 - 22 embryos demonstrated a left - right asymmetry, as migration of cells toward the genital ridges was usually restricted to the right, rather than the left, side of the mesentery. moreover, an examination of another group of cells that migrate through the dorsal mesentery, the enteric neural crest cells, revealed a similar preference for the right side of the mesentery, suggesting that the migratory pathway of pgcs is dictated by the mesentery itself. our findings provide new insights into the migration pathway of pgcs in the dorsal mesentery, and suggest a link between ema1, pgc migration and cell - cell interactions. these findings may contribute to a better understanding of the mechanism underlying migration of pgcs in avians. " [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']

There is a minor cleanliness issue: each abstract begins and ends with unneeded quotation marks. We will add a preprocessing step to remove them while batching samples.

Cleaning and batching#

Here we will interface between the Hugging Face tools and native PyTorch tools.

from torch.utils.data import Dataset, DataLoader, default_collate
def clean_and_tokenize(text_batch):
    """
    This function demonstrates how you can apply custom preprocessing logic while loading your data.
    
    It expects a list of plaintext abstracts as input. 
    """
    ## custom preprocessing
    # get rid of unwanted opening/closing quotes
    text_batch = [t[1:-1] for t in text_batch]
    
    ## tokenization
    # we use the huggingface tokenizer as above, padding to the longest sequence
    # in the batch and truncating anything longer than 512 tokens
    text_batch = tokenizer(text_batch, padding=True, truncation=True, max_length=512)
    
    return text_batch
    
def custom_collate(batch_list):
    """
    This is for use with the pytorch DataLoader class. We use the default collate function
    but add the cleaning and tokenization step. 
    """
    batch = default_collate(batch_list)
    batch['text'] = clean_and_tokenize(batch['text'])
    
    return batch

We can now use this collate function with the PyTorch DataLoader class to load, clean, tokenize, and batch our text data. Once we can do this, we’re ready to start modeling our data.

dl = DataLoader(dataset['train'], batch_size=3, collate_fn = custom_collate)
# Let's look at a batch
batch = next(iter(dl))
batch
{'text': {'input_ids': [[101, 1103, 2853, 1104, 4167, 21943, 4807, 1120, 1103, 2704, 173, 4894, 2883, 118, 15688, 10886, 21039, 113, 173, 4894, 2883, 117, 176, 14170, 1183, 114, 1108, 1771, 1107, 7079, 1112, 1141, 1104, 1103, 3778, 7844, 1104, 4167, 21943, 4807, 1107, 170, 5186, 2704, 1107, 176, 14170, 1183, 119, 173, 4894, 2883, 1108, 1103, 2364, 1104, 21718, 21501, 1183, 1105, 117, 1112, 1216, 117, 1141, 1104, 1103, 1211, 5918, 3057, 6425, 1104, 176, 14170, 1183, 119, 1142, 2440, 2820, 1110, 1145, 7226, 1118, 1103, 2704, 112, 188, 1607, 2111, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 3582, 131, 1195, 3402, 1103, 22760, 1104, 5677, 118, 2747, 12365, 25711, 4043, 10831, 1107, 170, 1372, 1104, 1119, 8031, 12809, 11179, 1200, 185, 7777, 9012, 10594, 1482, 1105, 170, 1372, 1104, 8362, 1394, 21601, 1482, 1105, 10788, 1103, 2398, 1206, 1103, 2915, 1104, 7503, 12365, 25711, 4043, 10831, 1105, 1103, 2781, 1104, 1103, 4010, 11932, 119, 4420, 1105, 4069, 131, 1141, 2937, 1105, 2570, 118, 1300, 1482, 1114, 173, 6834, 3186, 3491, 1465, 113, 4335, 3287, 117, 3102, 2636, 132, 1928, 1425, 1275, 119, 126, 1201, 132, 2079, 125, 118, 1627, 114, 9315, 3245, 8005, 11428, 5005, 131, 4376, 1125, 177, 119, 185, 7777, 9012, 8974, 113, 1955, 2636, 117, 1512, 3287, 114, 117, 1229, 5599, 113, 3413, 2636, 1105, 1955, 3287, 114, 117, 1127, 177, 119, 185, 7777, 9012, 118, 4366, 119, 1155, 14516, 1611, 1127, 7289, 1111, 1103, 2915, 1104, 131, 14247, 22331, 1348, 2765, 12365, 25711, 4043, 10831, 113, 185, 2599, 114, 117, 27799, 5318, 12365, 25711, 4043, 10831, 113, 1191, 1161, 114, 117, 17599, 7301, 8191, 1233, 12365, 25711, 4043, 10831, 117, 21153, 24081, 2858, 27515, 1394, 12365, 25711, 4043, 10831, 117, 1110, 5765, 2765, 12365, 25711, 4043, 10831, 117, 176, 25937, 11787, 1665, 5190, 1260, 8766, 8757, 22948, 2217, 12365, 25711, 4043, 10831, 117, 8050, 23503, 1233, 21284, 12365, 25711, 4043, 10831, 117, 188, 25710, 2386, 118, 4411, 2765, 12365, 25711, 4043, 10831, 132, 3245, 19091, 1179, 117, 185, 8043, 10606, 19790, 170, 117, 185, 8043, 10606, 19790, 172, 1105, 2848, 118, 177, 119, 185, 7777, 9012, 26491, 119, 1103, 1117, 2430, 7810, 1956, 1105, 1103, 190, 11811, 1105, 11019, 2571, 9077, 1127, 1145, 1737, 119, 2686, 131, 1103, 5625, 1104, 5677, 118, 2747, 12365, 25711, 4043, 10831, 1108, 2299, 1107, 4420, 1114, 177, 119, 185, 7777, 9012, 8974, 1190, 1107, 8362, 1394, 21601, 4420, 113, 22572, 1182, 1477, 118, 2774, 185, 133, 119, 1288, 1475, 114, 
119, 4418, 3245, 11048, 12365, 25711, 4043, 10831, 1127, 5409, 2299, 131, 1978, 1104, 1103, 4376, 177, 119, 185, 7777, 9012, 118, 3112, 1482, 1127, 185, 2599, 118, 3112, 1105, 1141, 1108, 1191, 1161, 118, 3112, 113, 22572, 1182, 1477, 118, 2774, 185, 134, 119, 1288, 1527, 114, 119, 1103, 2915, 1104, 12365, 25711, 4043, 10831, 1108, 1136, 2628, 1114, 1251, 7300, 1137, 25128, 21631, 17536, 5300, 1104, 3653, 119, 16421, 131, 1412, 2025, 11168, 170, 2398, 1206, 177, 119, 185, 7777, 9012, 8974, 1107, 5153, 1105, 1103, 2915, 1104, 5677, 118, 2747, 12365, 25711, 4043, 10831, 8362, 11192, 13335, 20175, 1114, 1251, 7300, 1137, 25128, 21631, 17536, 5300, 1104, 3653, 119, 1119, 8031, 12809, 11179, 1200, 185, 7777, 9012, 8974, 1107, 5153, 1180, 9887, 1103, 15415, 1104, 7300, 12365, 4060, 13601, 1673, 3245, 19091, 6620, 117, 1105, 120, 1137, 1168, 7300, 12365, 4060, 13601, 1673, 8131, 119, 102], [101, 3582, 131, 1103, 6457, 1104, 1103, 1675, 2025, 1108, 1106, 14133, 1103, 7300, 23891, 1104, 8276, 24928, 7880, 1874, 5822, 18574, 113, 187, 1179, 114, 1114, 24928, 7880, 3484, 118, 22620, 3384, 6059, 113, 183, 3954, 114, 1107, 12770, 4420, 1114, 25813, 1231, 7050, 2765, 1610, 16430, 7903, 113, 187, 19515, 114, 119, 4069, 131, 1103, 3783, 3403, 1108, 1982, 1107, 11030, 4611, 117, 1143, 28054, 2042, 3450, 1200, 117, 1950, 15339, 2598, 2904, 117, 1884, 1732, 18194, 3340, 117, 1105, 1301, 8032, 1513, 6597, 1146, 1106, 1260, 2093, 10615, 1368, 119, 1103, 3594, 3189, 2618, 126, 119, 122, 1105, 1103, 188, 18244, 3594, 7305, 191, 119, 1429, 119, 121, 1127, 1215, 1111, 18460, 119, 1103, 10653, 24576, 113, 1137, 1116, 114, 1105, 1157, 4573, 110, 6595, 14235, 113, 4573, 110, 172, 1182, 114, 1127, 10056, 1111, 7577, 119, 23470, 18460, 1127, 1982, 1359, 1113, 1103, 14601, 2060, 1104, 187, 19515, 119, 2686, 131, 1107, 1703, 117, 1275, 2527, 1114, 1275, 117, 21223, 187, 19515, 4420, 113, 128, 117, 4991, 1568, 5165, 1114, 187, 1179, 1105, 124, 117, 13743, 5165, 1114, 183, 3954, 114, 1127, 2700, 119, 1103, 4528, 1174, 10301, 113, 1137, 134, 122, 119, 4650, 117, 4573, 110, 172, 1182, 134, 122, 119, 1405, 118, 123, 119, 1405, 117, 185, 134, 121, 119, 3135, 1527, 114, 2799, 170, 5409, 2211, 2603, 1104, 4182, 118, 2747, 6209, 1107, 1103, 4420, 5165, 1114, 183, 3954, 3402, 1106, 187, 1179, 119, 1649, 117, 1185, 11435, 1193, 2418, 5408, 1127, 1276, 1107, 1103, 2603, 1104, 14601, 1231, 10182, 21629, 113, 1137, 134, 121, 119, 5731, 117, 4573, 110, 172, 1182, 134, 121, 119, 5486, 118, 122, 119, 5037, 117, 185, 134, 121, 119, 1489, 114, 1105, 13522, 113, 1137, 134, 121, 119, 5539, 117, 4573, 110, 172, 1182, 134, 121, 119, 4062, 118, 122, 119, 5519, 117, 185, 134, 121, 119, 5692, 114, 1206, 1103, 4420, 5165, 1114, 183, 3954, 1105, 187, 1179, 119, 1107, 1901, 117, 1155, 1103, 23470, 18460, 2756, 8080, 2686, 1114, 1103, 2905, 18460, 119, 16421, 131, 183, 3954, 1125, 1185, 5409, 1472, 1121, 187, 1179, 1107, 14601, 1231, 10182, 21629, 1105, 13522, 1111, 25813, 187, 19515, 119, 1649, 117, 1103, 5409, 2211, 2603, 1104, 4182, 118, 2747, 6209, 2726, 1103, 1329, 1104, 183, 3954, 1136, 1178, 1111, 187, 19515, 1114, 14601, 2060, 135, 125, 119, 121, 3975, 1133, 1145, 1111, 14601, 10855, 136, 125, 119, 121, 3975, 3402, 1114, 187, 1179, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}}

Saving code for later#

I’ve pulled the above code into a separate file called dataset.py so that we can reuse it in future notebooks. Download the file into your working directory:

wget https://raw.githubusercontent.com/clemsonciti/rcde_workshops/master/pytorch_llm/dataset.py

Let’s briefly look at the usage:

from dataset import PubMedDataset
dataset = PubMedDataset(
    root = "/project/rcde/datasets/pubmed/mesh_50k/splits/", 
    max_tokens = 20,
    tokenizer_model = "dmis-lab/biobert-base-cased-v1.2"
)
Found cached dataset text (/home/dane2/.cache/huggingface/datasets/text/default-cadbbf8acc2e2b5a/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)
dl_train = dataset.get_dataloader(split='train', batch_size=3) # split can be "train" or "test"
batch = next(iter(dl_train))
batch
{'input_ids': tensor([[  101,  1103,  2853,  1104,  4167, 21943,  4807,  1120,  1103,  2704,
           173,  4894,  2883,   118, 15688, 10886, 21039,   113,   173,   102],
        [  101,  3582,   131,  1195,  3402,  1103, 22760,  1104,  5677,   118,
          2747, 12365, 25711,  4043, 10831,  1107,   170,  1372,  1104,   102],
        [  101,  3582,   131,  1103,  6457,  1104,  1103,  1675,  2025,  1108,
          1106, 14133,  1103,  7300, 23891,  1104,  8276, 24928,  7880,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
dataset.decode_batch(batch['input_ids'])
['[CLS] the department of dermatology at the hospital dresden - friedrichstadt ( d [SEP]',
 '[CLS] background : we compared the prevalence of organ - specific autoantibodies in a group of [SEP]',
 '[CLS] background : the aim of the present study was to compare the clinical efficacy of radical neph [SEP]']
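As a quick sanity check, we can iterate over a few batches and confirm that each one comes back as a rectangular tensor whose width is capped by max_tokens. This is just an illustrative sketch using the objects created above:

# look at the shapes of a few batches
for i, batch in enumerate(dl_train):
    print(batch['input_ids'].shape)  # (batch_size, sequence length), at most max_tokens wide
    if i == 2:
        break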