How to Fine-Tune a Pretrained Hugging Face Model
Fine-tuning a pretrained Hugging Face model is one of the most effective ways to adapt state-of-the-art NLP models to your own tasks, although it can be challenging at first. In this step-by-step guide, we'll work through three examples ordered from easiest to most advanced: DistilBERT for sentiment analysis, BertForTokenClassification for named entity recognition (NER), and BertForQuestionAnswering with a SQuAD 2.0 checkpoint. You'll learn how to load pretrained checkpoints, prepare datasets, configure training, and evaluate model performance, with each step shown in working code. This article assumes familiarity with transformer architectures and the Hugging Face ecosystem. If you're new to these concepts, you may want to start with my beginner-friendly Hugging Face course.
Before diving in, keep a few points in mind. The code comments explain each step, so use them to follow what's happening. Because fine-tuning can take a long time, we train on a small subset of each dataset (shuffled first in the sentiment example). All three examples use the Hugging Face Trainer and TrainingArguments classes. The code still differs from task to task, because sequence classification, token classification, and question answering each use their own model class with different inputs and outputs. Evaluating the model after training is optional, but every example shows how to do it. If the scores are poor, you can adjust the setup and try again, for example by training on more samples or for more epochs.
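You will need the following Python libraries for the examples below: transformers, torch, scikit-learn, numpy, datasets, seqeval, and evaluate. Recent versions of transformers also require accelerate to use Trainer. The examples pass the tokenizer through Trainer's processing_class argument, which is available in transformers 4.46 and later. If you haven't set up Hugging Face and the required libraries yet, check out our installation and setup guide.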
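Before running the examples, you can check that everything is importable. This is a minimal sketch that uses only the standard library's importlib. The list uses import names rather than package names (scikit-learn imports as sklearn). It also includes accelerate, which recent versions of Trainer need.

```python
import importlib.util

# Import names for the libraries used in this article (scikit-learn imports as "sklearn").
REQUIRED_MODULES = ["transformers", "torch", "sklearn", "numpy", "datasets",
                    "seqeval", "evaluate", "accelerate"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required libraries are installed.")
```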
We'll fine-tune AutoModelForSequenceClassification on the stanfordnlp/imdb movie-review dataset. We start from the distilbert/distilbert-base-uncased-finetuned-sst-2-english checkpoint, which is already fine-tuned for sentiment classification on SST-2. We'll shuffle the dataset and select a small sample to speed up training.
Fine-Tuning DistilBERT with AutoModelForSequenceClassification (SST-2 Checkpoint on IMDB)
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
#1. Import and Load Model + Tokenizer
checkpoint = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
#2. Prepare Your Dataset
# Shuffle with a fixed seed, then keep a small subset to shorten training.
# IMDB's labels (0 = neg, 1 = pos) match this checkpoint's (0 = NEGATIVE, 1 = POSITIVE).
train_dataset = load_dataset("stanfordnlp/imdb", split="train").shuffle(seed=42).select(range(100))
val_dataset = load_dataset("stanfordnlp/imdb", split="test").shuffle(seed=42).select(range(30))
#3. Tokenize the Dataset
def tokenize(batch):
    # Truncate long reviews to the model's maximum input length. Padding is left to
    # the Trainer, whose default data collator pads each batch dynamically.
    return tokenizer(batch["text"], truncation=True)
train_dataset = train_dataset.map(tokenize, batched=True).remove_columns("text")
val_dataset = val_dataset.map(tokenize, batched=True).remove_columns("text")
#4. Define Metrics
from sklearn.metrics import accuracy_score, f1_score
#For a deeper understanding of argmax, see the explanation below.
def compute_metrics(pred):
    preds = pred.predictions.argmax(-1)  # logits -> predicted class ids
    labels = pred.label_ids
    return {"accuracy": accuracy_score(labels, preds), "f1": f1_score(labels, preds)}
#5. Set Up Training Arguments
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="distilbert_imdb",  # where checkpoints and logs are written
    logging_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)
#6. Initialize Trainer
from transformers import Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,  # replaces the deprecated tokenizer= argument
    compute_metrics=compute_metrics,
)
#7. Train the Model
trainer.train()
#8. Evaluate After Training
print(trainer.evaluate())
{'eval_loss': 0.09679087996482849, 'eval_accuracy': 0.9666666666666667, 'eval_f1': 0.967741935483871, 'eval_runtime': 0.7008, 'eval_samples_per_second': 42.81, 'eval_steps_per_second': 5.708, 'epoch': 3.0}
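argmax(-1) returns the index of the highest value along the last dimension of an array. In compute_metrics, pred.predictions holds the model's logits with shape (num_examples, num_labels). argmax(-1) therefore picks the predicted class for each example: 0 for negative or 1 for positive with this checkpoint.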
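To make this concrete, here is a small NumPy sketch. It applies argmax(-1) to a hand-made logits array shaped like pred.predictions and computes accuracy from the result. The logits and labels are invented for illustration.

```python
import numpy as np

# Invented logits for 3 reviews and 2 classes, shaped like pred.predictions:
# (num_examples, num_labels). Column 0 = negative, column 1 = positive.
logits = np.array([
    [2.1, -1.3],  # largest value in column 0 -> class 0
    [-0.4, 3.2],  # largest value in column 1 -> class 1
    [0.5, 0.7],   # largest value in column 1 -> class 1
])
labels = np.array([0, 1, 0])  # invented gold labels

preds = logits.argmax(-1)  # index of the largest value in each row
accuracy = float((preds == labels).mean())
print(preds.tolist(), round(accuracy, 3))  # [0, 1, 1] 0.667
```

Softmax preserves the ordering of values within a row. Taking argmax of the raw logits therefore gives the same class as taking it of the probabilities, so compute_metrics doesn't need to apply softmax first.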
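Next, we'll train a token classification model with BertForTokenClassification on the tomaarsen/conll2003 dataset. We start from the dslim/bert-base-NER checkpoint, which is already fine-tuned for NER. The dataset's columns are 'id', 'document_id', 'sentence_id', 'tokens', 'pos_tags', 'chunk_tags', and 'ner_tags'. We only need 'tokens' (the words of each sentence) and 'ner_tags' (one integer tag per word).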
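Because BERT splits words into sub-tokens, word-level tags have to be spread over tokens. The preprocessing gives each word's tag to its first sub-token and marks special tokens and later sub-tokens with -100, which the loss ignores. The tag ids must also match the checkpoint's label ids. The sketch below shows both steps on made-up data. The tag orders are my assumption about the CoNLL-2003 dataset features and the dslim/bert-base-NER config, so check dataset.features["ner_tags"] and model.config.label2id in your own environment.

```python
def align_labels(word_ids, word_tags, dataset_tag_names, model_label2id):
    """Spread word-level tag ids over sub-tokens and convert them to the model's label ids.

    word_ids:          per token, the index of its source word (None for special tokens)
    word_tags:         one dataset tag id per word
    dataset_tag_names: dataset tag id -> tag name
    model_label2id:    tag name -> model label id
    """
    aligned = []
    previous_word = None
    for word_idx in word_ids:
        if word_idx is None:
            aligned.append(-100)  # [CLS], [SEP]: ignored by the loss
        elif word_idx != previous_word:
            tag_name = dataset_tag_names[word_tags[word_idx]]
            aligned.append(model_label2id[tag_name])  # first sub-token of a word
        else:
            aligned.append(-100)  # later sub-tokens of the same word
        previous_word = word_idx
    return aligned


# Assumed tag orders (verify against dataset.features and model.config.label2id).
conll_tag_names = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
model_label2id = {"O": 0, "B-MISC": 1, "I-MISC": 2, "B-PER": 3, "I-PER": 4,
                  "B-ORG": 5, "I-ORG": 6, "B-LOC": 7, "I-LOC": 8}

# Made-up sentence ["EU", "rejects", "German", "call"] tagged B-ORG, O, B-MISC, O,
# with a made-up tokenization that splits "rejects" into two sub-tokens.
word_tags = [3, 0, 7, 0]
word_ids = [None, 0, 1, 1, 2, 3, None]

labels = align_labels(word_ids, word_tags, conll_tag_names, model_label2id)
print(labels)  # [-100, 5, 0, -100, 1, 0, -100]
```

Without the remapping, the B-ORG tag (dataset id 3) would be passed to the model as label 3, which this assumed config reads as B-PER.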
Fine-Tuning BertForTokenClassification for Named Entity Recognition (Example)
import numpy as np
from transformers import (
    BertTokenizerFast,
    BertForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
# Load model/tokenizer
model_name = "dslim/bert-base-NER"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)
id2label = model.config.id2label
label2id = {v: k for k, v in id2label.items()}
# Load a small subset of the dataset
dataset2 = load_dataset("tomaarsen/conll2003", split="train").select(range(200))
dataset3 = load_dataset("tomaarsen/conll2003", split="validation").select(range(20))
print(dataset2.column_names)
# The checkpoint's label ids need not follow the dataset's tag order, so map each
# dataset tag id to the model's id for the same tag name (e.g. "B-PER").
dataset_tag_names = dataset2.features["ner_tags"].feature.names
dataset_to_model_id = [label2id[name] for name in dataset_tag_names]
#Tokenize the Dataset
def preprocess_function(examples):
    # No padding here: the data collator pads input_ids and pads labels with -100
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens ([CLS], [SEP]) are ignored by the loss
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First sub-token of a word gets the word's tag, as a model label id
                label_ids.append(dataset_to_model_id[label[word_idx]])
            else:
                # Later sub-tokens of the same word are ignored
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
train_data = dataset2.map(preprocess_function, batched=True, remove_columns=dataset2.column_names)
test_data = dataset3.map(preprocess_function, batched=True, remove_columns=dataset3.column_names)
data_collator = DataCollatorForTokenClassification(tokenizer)
def align_predictions(predictions, label_ids):
    # Highest-scoring label per token; keep only positions that carry a real label
    preds = np.argmax(predictions, axis=2)
    batch_size, seq_len = preds.shape
    true_labels = []
    true_preds = []
    for i in range(batch_size):
        pred_tags = []
        true_tags = []
        for j in range(seq_len):
            if label_ids[i][j] != -100:
                true_tags.append(id2label[label_ids[i][j]])
                pred_tags.append(id2label[preds[i][j]])
        true_labels.append(true_tags)
        true_preds.append(pred_tags)
    return true_preds, true_labels
def compute_metrics(p):
    preds, trues = align_predictions(p.predictions, p.label_ids)
    return {
        "precision": precision_score(trues, preds),
        "recall": recall_score(trues, preds),
        "f1": f1_score(trues, preds),
    }
training_args = TrainingArguments(
    output_dir="my_ner_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)
trainer = Trainer(
    model,
    training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
results = trainer.evaluate()
print("\nEvaluation results:", results)
# Optional: print detailed classification report
predictions = trainer.predict(test_data)
pred_tags, true_tags = align_predictions(predictions.predictions, predictions.label_ids)
print("\nDetailed Classification Report:\n")
print(classification_report(true_tags, pred_tags))
Evaluation results: {'eval_loss': 0.47007593512535095, 'eval_runtime': 0.2464, 'eval_samples_per_second': 81.158, 'eval_steps_per_second': 8.116, 'epoch': 3.0}
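The align_predictions helper loops over every position. The same filtering can be written with a boolean mask per sentence. This sketch uses a made-up three-label id2label and hand-written logits. It only illustrates what align_predictions does and is not meant to replace it.

```python
import numpy as np


def to_tag_sequences(logits, label_ids, id2label):
    """Turn (batch, seq_len, num_labels) logits and -100-padded label ids into tag-name lists."""
    preds = logits.argmax(axis=-1)
    pred_seqs, true_seqs = [], []
    for pred_row, label_row in zip(preds, label_ids):
        keep = label_row != -100  # only positions that carry a real word label
        true_seqs.append([id2label[int(i)] for i in label_row[keep]])
        pred_seqs.append([id2label[int(i)] for i in pred_row[keep]])
    return pred_seqs, true_seqs


# Hypothetical 3-label scheme and one 4-token sentence: [CLS], word, word, [SEP]
id2label = {0: "O", 1: "B-PER", 2: "I-PER"}
logits = np.array([[
    [0.0, 0.0, 0.0],  # [CLS], ignored
    [0.1, 2.0, 0.3],  # predicted B-PER
    [1.0, 0.1, 0.5],  # predicted O
    [0.0, 0.0, 0.0],  # [SEP], ignored
]])
label_ids = np.array([[-100, 1, 2, -100]])

pred_tags, true_tags = to_tag_sequences(logits, label_ids, id2label)
print(pred_tags, true_tags)  # [['B-PER', 'O']] [['B-PER', 'I-PER']]
```

seqeval scores these lists at the entity level. Here the predicted one-token PER span does not match the two-token gold PER entity, so it counts as both a false positive and a false negative.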
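Let's fine-tune a question-answering model. It is the most time-consuming and complex of the three, but the process follows the same pattern as the previous examples; refer to them above if you need a refresher. Most of the extra work happens before and after training. In preprocessing, long contexts are split into overlapping windows and each answer's character span is mapped to token positions. In post-processing, the predicted token spans are mapped back to answer text. We will use BertForQuestionAnswering with the deepset/bert-base-cased-squad2 checkpoint on the rajpurkar/squad_v2 dataset.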
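The core of the training preprocessing is converting an answer's character span into start and end token indices within each window. Below is a standalone sketch of the same search the code performs, run on a hand-built window. The offsets and sequence ids are made up to resemble what the tokenizer returns with return_offsets_mapping=True. As in the code, a window that does not fully contain the answer is labeled (0, 0).

```python
def char_span_to_token_span(offsets, sequence_ids, start_char, end_char):
    """Return (start_token, end_token) for a character span, or (0, 0) if the window lacks it."""
    # Locate the context tokens (sequence id 1) in this window.
    context_start = sequence_ids.index(1)
    context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
    # Answer not fully inside this window: point both labels at [CLS] (index 0).
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0
    # Move right past every token that starts at or before start_char.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # Move left past every token that ends at or after end_char.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Hand-built window for question "Who wrote it?" and context "Ada wrote the notes."
# Tokens: [CLS] Who wrote it ? [SEP] Ada wrote the notes . [SEP]
context = "Ada wrote the notes."
sequence_ids = [None, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, None]
offsets = [(0, 0), (0, 3), (4, 9), (10, 12), (12, 13), (0, 0),
           (0, 3), (4, 9), (10, 13), (14, 19), (19, 20), (0, 0)]

answer = "the notes"
start_char = context.index(answer)
end_char = start_char + len(answer)
print(char_span_to_token_span(offsets, sequence_ids, start_char, end_char))  # (8, 9)
print(char_span_to_token_span(offsets, sequence_ids, 25, 30))  # (0, 0)
```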
Fine-Tuning BERT for Question Answering with deepset/bert-base-cased-squad2 (Example)
import collections
import numpy as np
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, BertForQuestionAnswering, TrainingArguments, Trainer
#Load small SQuAD v2 dataset
dataset2 = load_dataset("rajpurkar/squad_v2", split="train").select(range(50))
dataset3 = load_dataset("rajpurkar/squad_v2", split="validation").select(range(20))
model_name = "deepset/bert-base-cased-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)
# Long contexts are split into windows of max_length tokens that overlap by stride tokens
max_length = 384
stride = 128
def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []
    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        # SQuAD v2 contains unanswerable questions: label them (0, 0), the [CLS] position
        if len(answer["answer_start"]) == 0:
            start_positions.append(0)
            end_positions.append(0)
            continue
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)
        #Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1
        #If the answer is not fully inside the context, label is (0, 0). Otherwise it's the start and end token positions.
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
train_dataset = dataset2.map(
    preprocess_training_examples,
    batched=True,
    remove_columns=dataset2.column_names,
)
print(len(dataset2), len(train_dataset))
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []
    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        # Keep offsets only for context tokens so answers can't be taken from the question
        inputs["offset_mapping"][i] = [o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)]
    inputs["example_id"] = example_ids
    return inputs
validation_dataset = dataset3.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=dataset3.column_names,
)
print(len(dataset3), len(validation_dataset))
def postprocess_predictions(examples, features, raw_predictions, tokenizer, n_best_size=20, max_answer_length=30):
    all_start_logits, all_end_logits = raw_predictions
    # Group the tokenized windows (features) that belong to each original example
    example_id_to_index = {k["id"]: i for i, k in enumerate(examples)}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)
    predictions = collections.OrderedDict()
    for example in examples:
        example_index = example_id_to_index[example["id"]]
        feature_indices = features_per_example[example_index]
        min_null_score = None
        valid_answers = []
        context = example["context"]
        for feature_index in feature_indices:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            offset_mapping = features[feature_index]["offset_mapping"]
            input_ids = features[feature_index]["input_ids"]
            # The "no answer" score is the score of the [CLS]..[CLS] span
            cls_index = input_ids.index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or feature_null_score < min_null_score:
                min_null_score = feature_null_score
            # Only consider the n_best_size highest start and end logits
            start_indexes = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
            end_indexes = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip spans outside the context, reversed spans, and overly long spans
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or (end_index - start_index + 1) > max_answer_length
                    ):
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append({
                        "text": context[start_char:end_char],
                        "score": start_logits[start_index] + end_logits[end_index],
                    })
        # Predict "no answer" if there is no valid span or the null score beats the best span
        best_answer = max(valid_answers, key=lambda x: x["score"]) if valid_answers else None
        if best_answer is None or (min_null_score is not None and min_null_score > best_answer["score"]):
            predictions[example["id"]] = ""
        else:
            predictions[example["id"]] = best_answer["text"]
    formatted_predictions = [
        {"id": k, "prediction_text": v, "no_answer_probability": 1.0 if v == "" else 0.0}
        for k, v in predictions.items()
    ]
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return formatted_predictions, references
metric = evaluate.load("squad_v2")
def compute_metrics(eval_preds):
    # Trainer only calls compute_metrics when the eval features carry labels.
    # validation_dataset has no start/end positions, so we also evaluate manually below.
    preds, refs = postprocess_predictions(
        examples=dataset3,  # original eval examples (20 samples)
        features=validation_dataset,  # tokenized eval features
        raw_predictions=eval_preds.predictions,
        tokenizer=tokenizer,
    )
    return metric.compute(predictions=preds, references=refs)
training_args = TrainingArguments(
    output_dir="my_qa_checkpoints",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
trainer.save_model("./my_qa_model")
trainer.save_state()  # writes trainer_state.json (training log and progress) to output_dir
tokenizer.save_pretrained("./my_qa_model")
result = trainer.evaluate()
print("result: ", result)
#Evaluation
raw_preds = trainer.predict(validation_dataset).predictions
preds, refs = postprocess_predictions(dataset3, validation_dataset, raw_preds, tokenizer)
results = metric.compute(predictions=preds, references=refs)
print("Manual metrics:", results)
Manual metrics: {'exact': 75.0, 'f1': 75.0, 'total': 20, 'HasAns_exact': 90.0, 'HasAns_f1': 90.0, 'HasAns_total': 10, 'NoAns_exact': 60.0, 'NoAns_f1': 60.0, 'NoAns_total': 10, 'best_exact': 80.0, 'best_exact_thresh': 0.0, 'best_f1': 80.0, 'best_f1_thresh': 0.0}
For a more comprehensive learning path, explore our full Hugging Face course covering the ecosystem in depth.