Detecting Dementia Using Instruction-Tuned LLaMA

A cyborg Llama analysing a brain scan.

As part of a project at the University of Stavanger, my fellow student Kinnan Al Amir and I developed multiple AI models to detect dementia from speech transcriptions. I was tasked with creating two deep learning models: a fine-tuned RoBERTa model and an instruction-tuned LLaMA model. This article covers the LLaMA model.

Why this matters

Having a loved one seemingly lose all memory of you is a painful experience that the friends and families of over 50 million people live with [7]. Dementia affects not only memory but also thinking, and it makes it hard for patients to live a happy life. But dementia is not a disease in itself.

Dementia is a syndrome caused by a variety of diseases, with 60-70% of cases attributed to Alzheimer's disease [7]. This makes Alzheimer's disease the most common cause of dementia. Language impairment is one of the signs of Alzheimer's disease, and it is noticeable even in the early stages [4]. Patients have difficulty finding the right words and are often frustrated with themselves, which can lead to anxiety and depression. But these difficulties in expression also give hope for early diagnosis through language analysis.

To support the development of tools that can diagnose Alzheimer's disease early, Saturnino Luz et al. [6] created the Alzheimer’s Dementia Recognition through Spontaneous Speech (ADReSS) challenge. Part of the challenge is to predict whether a patient has Alzheimer's disease based on a speech sample. The challenge provides a dataset of transcribed speech samples from patients with Alzheimer's disease and from healthy controls.


Figure 1: The Cookie Theft picture

The speech samples were recorded while participants described the Cookie Theft picture shown in Figure 1. The picture is part of the Boston Diagnostic Aphasia Exam [3] and is used to assess a patient's language capabilities. We used this dataset to train multiple machine learning and deep learning models. This article covers how I instruction-tuned the LLaMA 7B model to reach an accuracy of 75%.

Instruction-tuning LLaMA

LLaMA is a large language model developed by Meta. To fine-tune LLaMA, we go through the following steps:

  1. Load the model and tokenizer
  2. Load the dataset
  3. Create prompts
  4. Tokenize the data
  5. Fine-tune the model using PEFT!

We import all the libraries we need with the following code:


import os
from random import randrange
from functools import partial
import torch
from datasets import load_dataset
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          HfArgumentParser,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          EarlyStoppingCallback,
                          pipeline,
                          logging,
                          set_seed)

import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, AutoPeftModelForCausalLM
from trl import SFTTrainer

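Loading the Model and Tokenizer

Before loading the model, we define a bitsandbytes quantization configuration. It stores the model weights in 4-bit precision, which greatly reduces the memory needed to hold the 7B model, while computations still run in bfloat16 (bnb_4bit_compute_dtype).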


def create_bnb_config(load_in_4bit, bnb_4bit_use_double_quant, bnb_4bit_quant_type, bnb_4bit_compute_dtype):
    bnb_config = BitsAndBytesConfig(
        load_in_4bit = load_in_4bit,
        bnb_4bit_use_double_quant = bnb_4bit_use_double_quant,
        bnb_4bit_quant_type = bnb_4bit_quant_type,
        bnb_4bit_compute_dtype = bnb_4bit_compute_dtype,
    )

    return bnb_config

We load the model from the Hugging Face Hub with a helper function. It applies our bnb settings, uses device_map="auto" to distribute the model over the available GPUs, and caps the memory the model may use on each GPU.


def load_model(model_name, bnb_config):
    # Count the available GPUs and cap the memory the model may use on each one
    n_gpus = torch.cuda.device_count()
    max_memory = "40960MB"

    # Load the quantized model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config = bnb_config,
        device_map = "auto", # dispatch the model efficiently on the available resources
        max_memory = {i: max_memory for i in range(n_gpus)},
    )

    # Load the matching tokenizer (pass token=... only if the model repository is gated)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The LLaMA tokenizer defines no padding token, so reuse the EOS token for padding
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

Let us now use the create_bnb_config and load_model functions to load the model and its tokenizer.


model_name = "beomi/llama-2-ko-7b"

# Activate 4-bit precision base model loading
load_in_4bit = True

# Activate nested quantization for 4-bit base models (double quantization)
bnb_4bit_use_double_quant = True

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Compute data type for 4-bit base models
bnb_4bit_compute_dtype = torch.bfloat16

bnb_config = create_bnb_config(load_in_4bit, bnb_4bit_use_double_quant, bnb_4bit_quant_type, bnb_4bit_compute_dtype)

model, tokenizer = load_model(model_name, bnb_config)
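To put the 4-bit settings above into numbers, here is a rough back-of-the-envelope estimate of the memory needed for the weights alone. It is a sketch under stated assumptions: about 7 billion weights, and the block sizes of 64 and 256 that the QLoRA paper uses to account for the quantization constants. Activations, LoRA weights, optimizer state and layers that stay in 16-bit are ignored, and the function names are hypothetical.

```python
# Rough weight-memory estimate for a ~7B-parameter model at different precisions.
# Assumptions: 7e9 weights, 4-bit blocks of 64 weights with one float32 absmax each,
# and (for double quantization) 8-bit absmax values plus one float32 constant per
# 256 blocks. Activations, LoRA weights, optimizer state and 16-bit layers are ignored.

def bits_per_param(double_quant: bool, blocksize: int = 64, blocksize2: int = 256) -> float:
    """Effective bits per weight for 4-bit storage, including quantization constants."""
    if double_quant:
        overhead = 8 / blocksize + 32 / (blocksize * blocksize2)
    else:
        overhead = 32 / blocksize
    return 4 + overhead


def weight_memory_gb(n_params: float, bits: float) -> float:
    """Memory for n_params weights at the given bits per weight, in decimal gigabytes."""
    return n_params * bits / 8 / 1e9


n_params = 7e9
print(f"16-bit:                  {weight_memory_gb(n_params, 16):.2f} GB")
print(f"4-bit:                   {weight_memory_gb(n_params, bits_per_param(False)):.2f} GB")
print(f"4-bit, double quantized: {weight_memory_gb(n_params, bits_per_param(True)):.2f} GB")
```

Under these assumptions, 4-bit storage cuts the weights from about 14 GB to roughly 3.9 GB, and double quantization trims the constant overhead from 0.5 to about 0.127 bits per weight, saving roughly another 0.3 GB.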

Loading the Dataset

We load the data as a Hugging Face Dataset object instead of a pandas DataFrame. The Dataset class is backed by Apache Arrow and memory-mapped from disk, which keeps memory usage low, and it comes with built-in preprocessing methods such as map, filter and shuffle.


# CSV files with the transcripts of the control group and the dementia group
dataset_name = ["./Control_db.csv", "./Dementia_db.csv"]

# Load dataset
dataset = load_dataset("csv", data_files = dataset_name, split='train')

print(f'Number of prompts: {len(dataset)}')
print(f'Column names are: {dataset.column_names}')

Creating the Tuning Prompts

To fine-tune the LLaMA model, we turn every data point into a prompt that contains an instruction, the transcription, and the label mapped to a string. Label 0 is mapped to 'healthy' and label 1 to 'alzheimers'.


def create_prompt_formats(sample, instruction = None):
    label_map = {0: "healthy", 1: "alzheimers"}
    if instruction is None:
        instruction = "The input is a transcription of a patient who could have the alzheimers disease. Based on the transcription respond with 'healthy' or 'alzheimers' according to the patients diagnosis."

    # Initialize static strings for the prompt template
    INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    INSTRUCTION_KEY = "### Instruction:"
    INPUT_KEY = "Input:"
    RESPONSE_KEY = "### Response:"
    END_KEY = "### End"

    # Combine a prompt with the static strings
    blurb = f"{INTRO_BLURB}"
    instruction = f"{INSTRUCTION_KEY}\n{instruction}"
    input_context = f"{INPUT_KEY}\n{sample['Transcript']}" if sample['Transcript'] else None
    response = f"{RESPONSE_KEY}\n{label_map[sample['Category']]}"
    end = f"{END_KEY}"

    # Create a list of prompt template elements
    parts = [part for part in [blurb, instruction, input_context, response, end] if part]

    # Join prompt template elements into a single string to create the prompt template
    formatted_prompt = "\n\n".join(parts)

    # Store the formatted prompt in a new "text" key
    sample["text"] = formatted_prompt

    return sample

We can try out the function by running the following snippet:


create_prompt_formats(dataset[randrange(len(dataset))])

This will give us the output:


{'Language': 'eng',
 'Data': 'Pitt',
 'Participant': 'PAR',
 'Age': 61,
 'Gender': 'female',
 'Diagnosis': 'Control',
 'Category': 0,
 'mmse': 30.0,
 'Filename': 'S033',
 'Transcript': " mhm . well the water's running over on the floor . &uh the chair [: stool] [* s:r] is tilting . the boy is into the cookie jar . and his sister is reaching for a cookie . the mother's drying dishes . &um do you want action or just want anything I see ? okay . mhm .",
 'text': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nThe input is a transcription of a patient who could have the alzheimers disease. Based on the transcription respond with 'healthy' or 'alzheimers' according to the patients diagnosis.\n\nInput:\n mhm . well the water's running over on the floor . &uh the chair [: stool] [* s:r] is tilting . the boy is into the cookie jar . and his sister is reaching for a cookie . the mother's drying dishes . &um do you want action or just want anything I see ? okay . mhm .\n\n### Response:\nhealthy\n\n### End"}

We can see that a new column containing our newly constructed prompt was added.

Tokenizing the data

Before tokenizing, we have to find the maximum token length from the model configuration.


def get_max_length(model):
    # Pull model configuration
    conf = model.config
    # Initialize "max_length" as None until a setting is found
    max_length = None
    # Look for the maximum sequence length under the names different architectures use
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(conf, length_setting, None)
        if max_length:
            print(f"Found max length: {max_length}")
            break
    # Fall back to 1024 if none of these settings exist in the model configuration
    if not max_length:
        max_length = 1024
        print(f"Using default max length: {max_length}")
    return max_length

We also prepare a function for tokenizing.


def preprocess_batch(batch, tokenizer, max_length):
    return tokenizer(
        batch["text"],
        max_length = max_length,
        truncation = True,
    )

At this point, we have prepared our data and the functions for prompt generation and tokenization, so we are ready to preprocess the dataset.

The following function takes in our dataset and returns the tokenized data.


def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed: int, dataset):
    # Add prompt to each sample
    print("Preprocessing dataset…")
    dataset = dataset.map(create_prompt_formats)

    # Tokenize each batch and drop the original CSV columns and "text", keeping only the tokenizer outputs
    _preprocessing_function = partial(preprocess_batch, max_length = max_length, tokenizer = tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched = True,
        remove_columns = ['Language', 'Data','Participant','Age','Gender','Diagnosis','Category','mmse','Filename','Transcript','text'],
    )

    # Drop samples that reached max_length, i.e. those cut off by truncation
    dataset = dataset.filter(lambda sample: len(sample["input_ids"]) < max_length)

    # Shuffle dataset
    dataset = dataset.shuffle(seed = seed)

    return dataset
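The filter step is easy to misread: truncation has already capped every sample at max_length tokens, so nothing can exceed it. The toy simulation below, which uses plain integer lists as hypothetical stand-ins for token ids, shows what the "< max_length" check actually does: it drops exactly the samples that reached the cap. Because the label and "### End" sit at the end of each prompt, those are the samples whose label may have been cut off.

```python
# Toy simulation of the truncate-then-filter steps in preprocess_dataset.
# Integer lists stand in for token ids; only their lengths matter here.

def truncate(ids, max_length):
    # Same effect as right-side truncation in tokenizer(..., truncation=True)
    return ids[:max_length]


def keep_untruncated(samples, max_length):
    truncated = [truncate(ids, max_length) for ids in samples]
    # No sample can exceed max_length any more, so "< max_length" removes exactly
    # the samples that hit the cap, i.e. the ones whose end (and label) may be missing
    return [ids for ids in truncated if len(ids) < max_length]


max_length = 8
samples = [list(range(5)), list(range(8)), list(range(12))]  # lengths 5, 8 and 12
kept = keep_untruncated(samples, max_length)
print([len(ids) for ids in kept])  # [5]
```

One side effect is that a prompt whose length equals max_length exactly is dropped too, even though nothing was cut. Since the model's limit is far longer than a typical Cookie Theft transcription, this edge case should rarely matter.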

We call it like this:


# Random seed
seed = 33

max_length = get_max_length(model)
preprocessed_dataset = preprocess_dataset(tokenizer, max_length, seed, dataset)

Fine-Tuning the model using PEFT

To fine-tune our LLaMA model efficiently, we use Parameter-Efficient Fine-Tuning (PEFT), specifically Low-Rank Adaptation (LoRA). LoRA freezes the pretrained weights and trains a pair of small low-rank matrices for each adapted layer, so only a small fraction of the parameters is updated. Combined with the 4-bit quantized base model (an approach known as QLoRA), this keeps the memory and compute requirements modest.
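A minimal pure-Python sketch of what a LoRA-adapted linear layer computes may help here. The frozen weight W produces the usual output, and the trainable low-rank pair A (r × d_in) and B (d_out × r) adds a correction scaled by alpha / r, which is PEFT's default scaling. The matrices, sizes and helper names are illustrative assumptions, and dropout and bias are left out.

```python
# Minimal pure-Python sketch of a LoRA-adapted linear layer:
#     h = W x + (alpha / r) * B (A x)
# W (d_out x d_in) is the frozen pretrained weight; only A (r x d_in) and
# B (d_out x r) would be trained. Dropout and bias are omitted for clarity.

def matvec(M, x):
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha):
    r = len(A)                        # LoRA rank = number of rows of A
    base = matvec(W, x)               # output of the frozen pretrained layer
    update = matvec(B, matvec(A, x))  # low-rank correction B (A x)
    scale = alpha / r                 # PEFT's default LoRA scaling
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]   # toy 2 x 2 "pretrained" weight
A = [[1.0, 1.0]]               # rank r = 1
B_init = [[0.0], [0.0]]        # B starts at zero, as in PEFT's default initialization
B_trained = [[1.0], [0.0]]     # pretend value of B after some training
x = [1.0, 2.0]

print(lora_forward(W, A, B_init, x, alpha=4))     # [1.0, 2.0]
print(lora_forward(W, A, B_trained, x, alpha=4))  # [13.0, 2.0]
```

Because B starts at zero, the adapted layer initially reproduces the pretrained output exactly, and training gradually moves it away. With the settings used later in this article (r = 16, lora_alpha = 64), the scaling factor is 64 / 16 = 4.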

Initializing PEFT Configuration

We configure LoRA by initializing the necessary parameters. The helper functions below build the LoRA configuration, find the linear layers that receive adapters, and report how many parameters are actually trainable. Only the small LoRA matrices are updated during training, which keeps the process resource-efficient.


def create_peft_config(r, lora_alpha, target_modules, lora_dropout, bias, task_type):
    config = LoraConfig(
        r = r,
        lora_alpha = lora_alpha,
        target_modules = target_modules,
        lora_dropout = lora_dropout,
        bias = bias,
        task_type = task_type,
    )

    return config


def find_all_linear_names(model):
    """
    Find the names of the 4-bit linear layers to apply LoRA to.

    :param model: quantized base model (before it is wrapped with PEFT)
    """

    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    # Exclude the output head, which should not receive a LoRA adapter
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    print(f"LoRA module names: {list(lora_module_names)}")
    return list(lora_module_names)


def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0

    for _, param in model.named_parameters():
        num_params = param.numel()
        # DeepSpeed ZeRO-3 partitions parameters, so numel() can be 0; use ds_numel instead
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        # bitsandbytes packs two 4-bit weights into one uint8 element, so double the count
        if param.__class__.__name__ == "Params4bit":
            num_params *= 2
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    print(f"All Parameters: {all_param:,d} || Trainable Parameters: {trainable_params:,d} || Trainable Parameters %: {100 * trainable_params / all_param:.4f}")
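As a rough check on what print_trainable_parameters should report, we can count the LoRA parameters by hand. For a linear layer with d_in inputs and d_out outputs, LoRA adds r·d_in + d_out·r trainable weights. The sketch below assumes the standard Llama 2 7B layer shapes (hidden size 4096, MLP size 11008, 32 decoder layers, full-size key and value projections) and assumes that find_all_linear_names returns the seven attention and MLP projections; the function names are hypothetical.

```python
# Count the weights LoRA adds when every linear projection of a Llama-style
# decoder is adapted. Assumption: standard Llama 2 7B shapes (hidden size 4096,
# MLP size 11008, 32 layers, no grouped-query attention); lm_head is excluded.

def lora_params(d_in: int, d_out: int, r: int) -> int:
    # A is r x d_in and B is d_out x r
    return r * d_in + d_out * r


def llama_lora_params(r: int, hidden: int = 4096, mlp: int = 11008, n_layers: int = 32) -> int:
    per_layer = (
        4 * lora_params(hidden, hidden, r)  # q_proj, k_proj, v_proj, o_proj
        + 2 * lora_params(hidden, mlp, r)   # gate_proj, up_proj
        + lora_params(mlp, hidden, r)       # down_proj
    )
    return n_layers * per_layer


print(f"{llama_lora_params(r=16):,}")  # 39,976,960
```

With r = 16 this comes to about 40 million trainable weights, well under 1% of the roughly 7 billion parameters in the base model.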

Fine-tuning the Pre-trained Model

With LoRA configured, we can start fine-tuning. The training arguments set the number of training steps, the learning rate, and the other hyperparameters that govern training.


def fine_tune(model, tokenizer, dataset, lora_r, lora_alpha, lora_dropout, bias, task_type,
              per_device_train_batch_size, gradient_accumulation_steps, warmup_steps, max_steps,
              learning_rate, fp16, logging_steps, output_dir, optim):
    # Enable gradient checkpointing to reduce memory usage during fine-tuning
    model.gradient_checkpointing_enable()

    # Prepare the model for training
    model = prepare_model_for_kbit_training(model)

    # Get LoRA module names
    target_modules = find_all_linear_names(model)

    # Create PEFT configuration for these modules and wrap the model to PEFT
    peft_config = create_peft_config(lora_r, lora_alpha, target_modules, lora_dropout, bias, task_type)
    model = get_peft_model(model, peft_config)

    # Print information about the percentage of trainable parameters
    print_trainable_parameters(model)

    # Training parameters
    trainer = Trainer(
        model = model,
        train_dataset = dataset,
        args = TrainingArguments(
            per_device_train_batch_size = per_device_train_batch_size,
            gradient_accumulation_steps = gradient_accumulation_steps,
            warmup_steps = warmup_steps,
            max_steps = max_steps,
            learning_rate = learning_rate,
            fp16 = fp16,
            logging_steps = logging_steps,
            output_dir = output_dir,
            optim = optim,
        ),
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
    )

    # Disable the KV cache during training; it is only useful for generation and conflicts with gradient checkpointing
    model.config.use_cache = False

    # Launch training and log metrics
    print("Training…")
    train_result = trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    print(metrics)

    # Save model
    print("Saving last checkpoint of the model…")
    os.makedirs(output_dir, exist_ok = True)
    trainer.model.save_pretrained(output_dir)

    # Free memory for merging weights
    del model
    del trainer
    torch.cuda.empty_cache()

Training the Model

Finally, we can train the model. You can see our parameters in the following code snippet:


# LoRA rank (inner dimension of the low-rank update matrices)
lora_r = 16

# Alpha parameter for LoRA scaling
lora_alpha = 64

# Dropout probability for LoRA layers
lora_dropout = 0.1

# Bias
bias = "none"

# Task type
task_type = "CAUSAL_LM"


# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"

# Batch size per GPU for training
per_device_train_batch_size = 1

# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 4

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Optimizer to use
optim = "paged_adamw_32bit"

# Number of training steps (overrides num_train_epochs)
max_steps = 1000

# Linear warmup steps from 0 to learning_rate
warmup_steps = 2

# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = True

# Log every X update steps
logging_steps = 1

fine_tune(model,
          tokenizer,
          preprocessed_dataset,
          lora_r,
          lora_alpha,
          lora_dropout,
          bias,
          task_type,
          per_device_train_batch_size,
          gradient_accumulation_steps,
          warmup_steps,
          max_steps,
          learning_rate,
          fp16,
          logging_steps,
          output_dir,
          optim)
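To see how much training 1000 steps represents, it helps to work out the effective batch size and the number of passes over the data. This sketch assumes a single training process (device_map="auto" splits the model across GPUs rather than replicating it, so the per-device batch is not multiplied by the GPU count) and uses a hypothetical training-set size, since the number of prompts left after preprocessing is not stated here.

```python
# Back-of-the-envelope training budget for the hyperparameters above.
# Assumptions: one training process and a hypothetical training-set size n_train.

def training_budget(max_steps, per_device_batch, grad_accum, n_train):
    effective_batch = per_device_batch * grad_accum   # prompts per optimizer step
    samples_seen = max_steps * effective_batch        # prompts processed in total
    epochs = samples_seen / n_train                   # passes over the training set
    return effective_batch, samples_seen, epochs


# n_train = 400 is a placeholder, not the project's actual dataset size
print(training_budget(max_steps = 1000, per_device_batch = 1, grad_accum = 4, n_train = 400))
# (4, 4000, 10.0)
```

Each optimizer step therefore sees 4 prompts, and training covers 4,000 prompts in total. For a training set of N prompts that is 4000 / N epochs, so if the set holds a few hundred transcripts, every sample is seen many times over.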

Hooray! We have now successfully fine-tuned the model! All that is left is to load the model and use it to predict our test data.

Testing the model

We load the fine-tuned LoRA adapter that was saved in the previous step. AutoPeftModelForCausalLM also loads the base model the adapter was trained on.


# Load fine-tuned weights
output_dir = "./results"

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map = "auto", torch_dtype = torch.bfloat16)

We created the test prompts with the same template, but this time each prompt stops after "### Response:" so that the model has to generate the label.
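The code for building the test prompts is not shown, so here is a minimal sketch of what it might look like. It assumes the test transcriptions are available as a list of strings; the helper name create_test_prompt and the example transcript are hypothetical. It reuses the template from create_prompt_formats, keeping the instruction text identical to the training prompt, and stops after "### Response:", producing the list that the generation call below receives as queries.

```python
# Hypothetical test-prompt builder: the training template without the label and "### End".
# The instruction text is copied verbatim from create_prompt_formats, because the
# prompt at test time should match what the model saw during fine-tuning.

INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION = (
    "The input is a transcription of a patient who could have the alzheimers disease. "
    "Based on the transcription respond with 'healthy' or 'alzheimers' according to the patients diagnosis."
)


def create_test_prompt(transcript: str) -> str:
    parts = [
        INTRO_BLURB,
        f"### Instruction:\n{INSTRUCTION}",
        f"Input:\n{transcript}",
        "### Response:",  # the model is expected to continue with the label
    ]
    return "\n\n".join(parts)


# Placeholder transcript; in practice these come from the test set's "Transcript" column
test_transcripts = ["the boy is on the stool . the water is running over ."]
queries = [create_test_prompt(t) for t in test_transcripts]
print(queries[0])
```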


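# Merge the LoRA weights into the base model and wrap it in a text-generation pipeline
model = model.merge_and_unload()
generator = pipeline("text-generation", model = model, tokenizer = tokenizer)

# queries is the list of test prompts ending with "### Response:"
sequences = generator(
    queries,
    num_return_sequences = 1,
    eos_token_id = tokenizer.eos_token_id,
    max_new_tokens = 10,
    do_sample = False, # greedy decoding gives a deterministic label
)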

To extract the response, we use the following function:


def extract_responses(sequences):
    responses = []
    for sequence in sequences:
        for item in sequence:
            # Split the text to find the part after "### Response:\n"
            parts = item['generated_text'].split("### Response:\n")
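            if len(parts) > 1:
                # Further split to isolate the response before "\n\n### End"
                response_part = parts[1].split("\n\n### End")[0]
                responses.append(response_part.strip())
            else:
                # Keep one entry per query so the responses stay aligned with the labels
                responses.append("")
    return responses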

responses = extract_responses(sequences)
print(responses)

Finally, with some more computations, we can create a confusion matrix:

Confusion matrix for the test data.

Accuracy matrix for the test data.
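The computations behind these figures are not shown. Here is a minimal sketch of how they could be done, assuming responses comes from extract_responses and labels holds the test-set Category values (0 = healthy, 1 = alzheimers) in the same order. Only the first word of each response is compared, and any answer other than the two label strings is counted in a separate "invalid" column rather than silently dropped; the function names are hypothetical.

```python
from collections import Counter

LABEL_MAP = {0: "healthy", 1: "alzheimers"}
CLASSES = ["healthy", "alzheimers"]


def normalize(response: str) -> str:
    # Keep only the first word, lower-cased, in case generation ran past the label
    words = response.split()
    return words[0].lower() if words else ""


def confusion_matrix(labels, responses):
    if len(labels) != len(responses):
        raise ValueError("labels and responses must have the same length")
    counts = Counter()
    for label, response in zip(labels, responses):
        predicted = normalize(response)
        if predicted not in CLASSES:
            predicted = "invalid"
        counts[(LABEL_MAP[label], predicted)] += 1
    # Rows: true class; columns: predicted class plus "invalid"
    return {true: {pred: counts[(true, pred)] for pred in CLASSES + ["invalid"]} for true in CLASSES}


def accuracy(labels, responses):
    matrix = confusion_matrix(labels, responses)
    correct = sum(matrix[c][c] for c in CLASSES)
    return correct / len(labels)


# Toy example, not the project's results
labels = [0, 0, 1, 1, 1]
responses = ["healthy", "alzheimers", "alzheimers", "healthy", "dementia"]
print(confusion_matrix(labels, responses))
print(accuracy(labels, responses))  # 0.4
```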

Closing thoughts

In this project, we were tasked with predicting dementia using large language models. We chose instruction tuning to show an alternative to training a classification layer on top of the model. When we first started tuning this model, it could not even restrict its output to 'healthy' and 'alzheimers'. Only after increasing the training time to 1000 steps did we achieve the results shown above. What will happen if we increase it even more?