Content from Preparing to train a model



Overview

Questions

  • For what prediction tasks is machine learning an appropriate tool?
  • How can inappropriate target variable choice lead to suboptimal outcomes in a machine learning pipeline?
  • What forms of “bias” can occur in machine learning, and where do these biases come from?

Objectives

  • Judge what tasks are appropriate for machine learning
  • Understand why the choice of prediction task / target variable is important.
  • Describe how bias can appear in training data and algorithms.

Choosing appropriate tasks


Machine learning is a rapidly advancing, powerful technology that is helping to drive innovation. Before embarking on a machine learning project, we need to consider the task carefully. Many machine learning efforts are not solving problems that need to be solved. Or, the problem may be valid, but the machine learning approach makes incorrect assumptions and fails to solve the problem effectively. Worse, many applications of machine learning are not for the public good.

We will start by considering the NIH Guiding Principles for Ethical Research, which provide a useful set of considerations for any project.

Challenge

Take a look at the NIH Guiding Principles for Ethical Research.

What are the main principles?

A summary of the principles is listed below:

  • Social and clinical value: Does the social or clinical value of developing and implementing the model outweigh the risk and burden to the people involved?
  • Scientific validity: Once created, will the model provide valid, meaningful outputs?
  • Fair subject selection: Are the people who contribute and benefit from the model selected fairly, and not through vulnerability, privilege, or other unrelated factors?
  • Favorable risk-benefit ratio: Do the potential benefits of developing and implementing the model outweigh the risks?
  • Independent review: Has the project been reviewed by someone independent of the project, and has an Institutional Review Board (IRB) been approached where appropriate?
  • Informed consent: Are participants whose data contributes to development and implementation of the model, as well as downstream recipients of the model, kept informed?
  • Respect for potential and enrolled subjects: Is the privacy of participants respected and are steps taken to continuously monitor the effect of the model on downstream participants?

AI tasks are often most controversial when they involve human subjects, and especially visual representations of people. We’ll discuss two case studies that use people’s faces as a prediction tool, and discuss whether these uses of AI are appropriate.

Case study 1: Diagnosing genetic disorders

In 2019, Nature Medicine published a paper that describes a model that can identify genetic disorders from a photograph of a patient’s face. The abstract of the paper is copied below:

Syndromic genetic conditions, in aggregate, affect 8% of the population. Many syndromes have recognizable facial features that are highly informative to clinical geneticists. Recent studies show that facial analysis technologies measured up to the capabilities of expert clinicians in syndrome identification. However, these technologies identified only a few disease phenotypes, limiting their role in clinical settings, where hundreds of diagnoses must be considered. Here we present a facial image analysis framework, DeepGestalt, using computer vision and deep-learning algorithms, that quantifies similarities to hundreds of syndromes.

DeepGestalt outperformed clinicians in three initial experiments, two with the goal of distinguishing subjects with a target syndrome from other syndromes, and one of separating different genetic sub-types in Noonan syndrome. On the final experiment reflecting a real clinical setting problem, DeepGestalt achieved 91% top-10 accuracy in identifying the correct syndrome on 502 different images. The model was trained on a dataset of over 17,000 images representing more than 200 syndromes, curated through a community-driven phenotyping platform. DeepGestalt potentially adds considerable value to phenotypic evaluations in clinical genetics, genetic testing, research and precision medicine.

  • What is the proposed value of the algorithm?
  • What are the potential risks?
  • Are you supportive of this kind of research?
  • What safeguards, if any, would you want to be used when developing and using this algorithm?

Media reports about this paper were largely positive, e.g., reporting that clinicians are excited about the new technology.

Case study 2: Physiognomy

There is a long history of physiognomy, the “science” of trying to read someone’s character from their face. With the advent of machine learning, this discredited area of research has made a comeback. There have been numerous studies attempting to guess characteristics such as trustworthiness, criminality, and political and sexual orientation.

In 2018, for example, researchers suggested that neural networks could be used to detect sexual orientation from facial images. The abstract is copied below:

We show that faces contain much more information about sexual orientation than can be perceived and interpreted by the human brain. We used deep neural networks to extract features from 35,326 facial images. These features were entered into a logistic regression aimed at classifying sexual orientation. Given a single facial image, a classifier could correctly distinguish between gay and heterosexual men in 81% of cases, and in 74% of cases for women. Human judges achieved much lower accuracy: 61% for men and 54% for women. The accuracy of the algorithm increased to 91% and 83%, respectively, given five facial images per person.

Facial features employed by the classifier included both fixed (e.g., nose shape) and transient facial features (e.g., grooming style). Consistent with the prenatal hormone theory of sexual orientation, gay men and women tended to have gender-atypical facial morphology, expression, and grooming styles. Prediction models aimed at gender alone allowed for detecting gay males with 57% accuracy and gay females with 58% accuracy. Those findings advance our understanding of the origins of sexual orientation and the limits of human perception. Additionally, given that companies and governments are increasingly using computer vision algorithms to detect people’s intimate traits, our findings expose a threat to the privacy and safety of gay men and women.

  • What is the proposed value of the algorithm?
  • What are the potential risks?
  • Are you supportive of this kind of research?
  • What distinguishes this use of AI from the use of AI described in Case Study 1?

Media reports of this algorithm were largely negative, with a Scientific American article highlighting the connections to physiognomy and raising concern over government use of these algorithms:

This is precisely the kind of “scientific” claim that can motivate repressive governments to apply AI algorithms to images of their citizens. And what is to stop them from “reading” intelligence, political orientation and criminal inclinations from these images?

Choosing the outcome variable


Sometimes, choosing the outcome variable is easy: for instance, when building a model to predict how warm it will be out tomorrow, the temperature can be the outcome variable because it’s measurable (i.e., you know what temperature it was yesterday and today) and your predictions won’t cause a feedback loop (e.g., given a set of past weather data, the weather next Monday won’t change based on what your model predicts tomorrow’s temperature to be).

By contrast, sometimes it’s not possible to measure the target prediction subject directly, and sometimes predictions can cause feedback loops.

Case Study: Proxy variables

Consider the scenario described in the challenge below.

Challenge

Suppose that you work for a hospital and are asked to build a model to predict which patients are high-risk and need extra care to prevent negative health outcomes.

Discuss the following with a partner or small group:

  1. What is the goal target variable?
  2. What are challenges in measuring the target variable in the training data (i.e., former patients)?
  3. Are there other variables that are easier to measure, but can approximate the target variable, that could serve as proxies?
  4. How do social inequities interplay with the value of the target variable versus the value of the proxies?

The “challenge” scenario is not hypothetical: A well-known study by Obermeyer et al. analyzed an algorithm that hospitals used to assign patients risk scores for various conditions. The algorithm had access to various patient data, such as demographics (e.g., age and sex), the number of chronic conditions, insurance type, diagnoses, and medical costs. The algorithm did not have access to the patient’s race. The patient risk score determined the level of care the patient should receive, with higher-risk patients receiving additional care.

Ideally, the target variable would be health needs, but this can be challenging to measure: how do you compare the severity of two different conditions? Do you count chronic and acute conditions equally? In the system described by Obermeyer et al., the hospital decided to use health-care costs as a proxy for health needs, perhaps reasoning that this data is at least standardized across patients and doctors.

However, Obermeyer et al. reveal that the algorithm is biased against Black patients. That is, if there are two individuals – one white and one Black – with equal health, the algorithm tends to assign a higher risk score to the white patient, thus giving them access to higher care quality. The authors blame the choice of proxy variable for the racial disparities.

The authors go on to describe how, due to how health-care access is structured in the US, richer patients have more healthcare expenses, even if they are equally (un)healthy to a lower-income patient. The richer patients are also more likely to be white.

Consider the following:

  • How could the algorithm developers have caught this problem earlier?
  • Is this a technical mistake or a process-based mistake? Why?

Case study: Feedback loop

Consider social media, like Instagram or TikTok’s “for you page” or Facebook or Twitter’s newsfeed. The algorithms that determine what to show are complex (and proprietary!), but a large part of their objective is engagement: the number of clicks, views, or re-posts. This focus on engagement can create an “echo chamber,” where individual users see only content that aligns with their political ideology, thereby maximizing positive engagement with each post. But the impact of social media feedback loops spreads beyond politics: researchers have explored how similar feedback loops exist for mental health conditions such as eating disorders. If someone finds themselves in this corner of social media, it is likely because they have, or have risk factors for, an eating disorder; showing them pro-eating-disorder content can drive engagement but ultimately harm their mental health.

Consider the following questions:

  • Why do social media companies optimize for engagement?
  • What would be an alternative optimization target? How would the outcomes differ, both for users and for the companies’ profits?

Understanding bias


The term “bias” is overloaded, and can have the following definitions:

  • (Statistical) bias: the tendency of an algorithm to produce one solution over another, even though some alternatives may be just as good, or better. Statistical bias can have multiple sources, which we’ll get into below.
  • (Social) bias: outcomes are unfair to one or more social groups. Social bias can be the result of statistical bias (i.e., an algorithm giving preferential treatment to one social group over others), but can also occur outside of a machine learning context.

Sources of statistical bias

Algorithmic bias

Algorithmic bias is the tendency of an algorithm to favor one solution over another. Algorithmic bias is not always bad, and may sometimes be deliberately encoded by algorithm developers. For instance, linear regression with L0 regularization is biased towards sparse models (i.e., models where most weights are 0). This bias may be desirable in settings where human interpretability is important.
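To make this concrete, here is a minimal sketch using synthetic data; scikit-learn does not implement an L0 penalty directly, so the sketch uses the closely related L1 (lasso) penalty, which also biases the solution toward sparse weights:

PYTHON

import numpy as np
from sklearn.linear_model import Lasso, LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 20))                                     # 20 candidate features
y = 3 * X[:, 0] - 2 * X[:, 1] + rng.normal(scale=0.5, size=200)    # only 2 features matter

plain = LinearRegression().fit(X, y)
sparse = Lasso(alpha=0.1).fit(X, y)   # the L1 penalty prefers solutions with many zero weights

print("nonzero weights, ordinary least squares:", np.sum(np.abs(plain.coef_) > 1e-8))
print("nonzero weights, L1-regularized:", np.sum(np.abs(sparse.coef_) > 1e-8))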

But algorithmic bias can also occur unintentionally: for instance, if there is data bias (described below), this may lead algorithm developers to select an algorithm that is ill-suited to underrepresented groups. Then, even if the data bias is rectified, sticking with the original algorithm choice may not fix biased outcomes.

Data bias

Data bias is when the available training data is not accurate or representative of the target population. Data bias is extremely common (it’s often hard to collect perfectly representative and perfectly accurate data), and can arise in multiple ways:

  • Measurement error - if a tool is not well calibrated, measurements taken by that tool won’t be accurate. Likewise, human biases can lead to measurement error, for instance, if people systematically over-report their height on dating apps, or if doctors do not believe patients’ self-reports of their pain levels.
  • Response bias - for instance, when conducting a survey about customer satisfaction, customers who had very positive or very negative experiences may be more likely to respond.
  • Representation bias - the data is not representative of the whole population. For instance, doing clinical trials primarily on white men means that women and people of other races are not well represented in the data (a quick way to check for this is sketched below).
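Here is the quick check mentioned above: a minimal sketch that compares group proportions in a training dataset against reference population proportions. The dataset and the population numbers are made up for illustration:

PYTHON

import pandas as pd

# Hypothetical training data: one row per participant
train = pd.DataFrame({"sex": ["M"] * 700 + ["F"] * 300})

# Hypothetical reference proportions for the target population
population = {"M": 0.49, "F": 0.51}

observed = train["sex"].value_counts(normalize=True)
for group, expected in population.items():
    print(f"{group}: {observed.get(group, 0.0):.2f} in training data vs {expected:.2f} in population")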

Through the rest of this lesson, if we use the term “bias” without any additional context, we will be referring to social bias that stems from statistical bias.

Case Study

With a partner or small group, choose one of the three case study options. Read or watch individually, then discuss as a group how bias manifested in the training data, and what strategies could correct for it.

After the discussion, share with the whole workshop what you discussed.

  1. Predictive policing
  2. Facial recognition (video, 5 min.)
  3. Amazon hiring tool

Key Points

  • Some tasks are not appropriate for machine learning due to ethical concerns.
  • Machine learning tasks should have a valid prediction target that maps clearly to the real-world goal.
  • Training data can be biased due to societal inequities, errors in the data collection process, and lack of attention to careful sampling practices.
  • “Bias” also refers to statistical bias, and certain algorithms can be biased towards some solutions.

Content from Scientific validity in the modeling process



Overview

Questions

  • What impact do overfitting and underfitting have on model performance?
  • What is data leakage?

Objectives

  • Implement at least two types of machine learning models in Python.
  • Describe the risks of, identify, and understand mitigation steps for overfitting and underfitting.
  • Understand why data leakage is harmful to scientific validity and how it can appear in machine learning pipelines.

Key Points

  • Overfitting is characterized by worse performance on the test set than on the train set and can be fixed by switching to a simpler model architecture or by adding regularization.
  • Underfitting is characterized by poor performance on both the training and test datasets. It can be fixed by collecting more training data, switching to a more complex model architecture, or improving feature quality.
  • Data leakage occurs when the model has access to the test data during training and results in overconfidence in the model’s performance.

Content from Model evaluation and fairness



Overview

Questions

  • How do we define fairness and bias in machine learning outcomes?
  • What types of bias and unfairness can occur in generative AI?
  • How can we improve the fairness of machine learning models?

Objectives

  • Reason about model performance through standard evaluation metrics.
  • Understand and distinguish between various notions of fairness in machine learning.
  • Describe and implement two different ways of modifying the machine learning modeling process to improve the fairness of a model.

Accuracy metrics


Stakeholders often want to know the accuracy of a machine learning model – what percent of predictions are correct? Accuracy can be decomposed into further metrics: e.g., in a binary prediction setting, recall (the fraction of positive samples that are classified correctly) and precision (the fraction of samples classified as positive that actually are positive) are commonly-used metrics.

Suppose we have a model that performs binary classification (+, -) on a test dataset of 1000 samples (let \(n\)=1000). A confusion matrix counts how many predictions fall into each of four quadrants: true class + predicted + (true positives), true class + predicted - (false negatives), true class - predicted + (false positives), and true class - predicted - (true negatives).

|             | True + | True - |
|-------------|--------|--------|
| Predicted + | 300    | 80     |
| Predicted - | 25     | 595    |

So, for instance, 80 samples are predicted as + but actually belong to class -.

We can compute the following metrics:

  • Accuracy: What fraction of predictions are correct?
    • (300 + 595) / 1000 = 0.895
    • Accuracy is 89.5%
  • Precision: What fraction of predicted positives are true positives?
    • 300 / (300 + 80) = 0.789
    • Precision is 78.9%
  • Recall: What fraction of true positives are classified as positive?
    • 300 / (300 + 25) = 0.923
    • Recall is 92.3%
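To double-check these numbers in code, here is a minimal sketch that computes the same metrics directly from the confusion-matrix counts above:

PYTHON

tp, fp = 300, 80    # predicted +: how many are truly +, truly -
fn, tn = 25, 595    # predicted -: how many are truly +, truly -
n = tp + fp + fn + tn          # 1000 samples in total

accuracy = (tp + tn) / n       # 0.895
precision = tp / (tp + fp)     # ~0.789
recall = tp / (tp + fn)        # ~0.923
print(f"accuracy={accuracy:.3f}, precision={precision:.3f}, recall={recall:.3f}")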

Callout

We’ve discussed binary classification, but other types of tasks use different metrics. For example,

  • Multi-class problems often use top-K accuracy, a measure of how often the true label appears among the model’s top K guesses.
  • Binary classifiers are also often evaluated with the Area Under the ROC Curve (AUC-ROC), which summarizes how well the classifier performs across different decision thresholds; regression tasks instead use metrics such as mean squared error or mean absolute error.

What accuracy metric to use?

Different accuracy metrics may be more relevant in different situations. Discuss with a partner or small groups whether precision, recall, or some combination of the two is most relevant in the following prediction tasks:

  1. Deciding what patients are high risk for a disease and who should get additional low-cost screening.
  2. Deciding what patients are high risk for a disease and should start taking medication to lower the disease risk. The medication is expensive and can have unpleasant side effects.
  1. It is best if all patients who need the screening get it, and there is little downside for doing screenings unnecessarily because the screening costs are low. Thus, a high recall score is optimal.

  2. Given the costs and side effects of the medicine, we do not want patients not at risk for the disease to take the medication. So, a high precision score is ideal.

How do we measure fairness?


What does it mean for a machine learning model to be fair or unbiased? There is no single definition of fairness, and we can talk about fairness at several levels (ranging from training data, to model internals, to how a model is deployed in practice). Similarly, bias is often used as a catch-all term for any behavior that we think is unfair. Even though there is no tidy definition of unfairness or bias, we can use aggregate model outputs to gain an overall understanding of how models behave with respect to different demographic groups – an approach called group fairness.

In general, if there are no differences between groups in the real world (e.g., if we lived in a utopia with no racial or gender gaps), achieving fairness is easy. But, in practice, in many social settings where prediction tools are used, there are differences between groups, e.g., due to historical and current discrimination.

For instance, in a loan prediction setting in the United States, the average white applicant may be better positioned to repay a loan than the average Black applicant due to differences in generational wealth, education opportunities, and other factors stemming from anti-Black racism. Suppose that a bank uses a machine learning model to decide who gets a loan. Suppose that 50% of white applicants are granted a loan, with a precision of 90% and a recall of 70% – in other words, 90% of white applicants granted loans end up repaying them, and 70% of all white applicants who would have repaid a loan, had they been given one, actually get a loan. Consider the following scenarios:

  • (Demographic parity) We give loans to 50% of Black applicants (the same rate as for white applicants), in a way that maximizes overall accuracy.
  • (Equalized odds) We give loans to X% of Black applicants, where X is chosen to maximize accuracy subject to keeping recall equal to 70% (and, strictly, the false positive rate equal across groups as well).
  • (Group level calibration) We give loans to X% of Black applicants, where X is chosen to maximize accuracy while keeping precision equal to 90%.

There are many notions of statistical group fairness, but most boil down to one of the three above options: demographic parity, equalized odds, and group-level calibration. All three are forms of distributional (or outcome) fairness. Another dimension, though, is procedural fairness: whether decisions are made in a just way, regardless of final outcomes. Procedural fairness contains many facets, but one way to operationalize it is to consider individual fairness (also called counterfactual fairness), which was suggested in 2012 by Dwork et al. as a way to ensure that “similar individuals [are treated] similarly”. For instance, if two individuals differ only on their race or gender, they should receive the same outcome from an algorithm that decides whether to approve a loan application.

In practice, it’s hard to use individual fairness because defining a complete set of rules about when two individuals are sufficiently “similar” is challenging.
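To make these notions concrete, the following sketch computes, for each group, the quantities the three definitions compare: selection rates (demographic parity), true and false positive rates (equalized odds), and precision (group-level calibration). The labels, predictions, and group assignments are randomly generated for illustration:

PYTHON

import numpy as np

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=1000)   # whether the applicant would have repaid the loan
y_pred = rng.integers(0, 2, size=1000)   # the model's loan decision
group = rng.integers(0, 2, size=1000)    # demographic group (0 or 1)

def group_rates(y_true, y_pred, mask):
    yt, yp = y_true[mask], y_pred[mask]
    selection_rate = yp.mean()       # demographic parity compares these across groups
    tpr = yp[yt == 1].mean()         # equalized odds compares TPR ...
    fpr = yp[yt == 0].mean()         # ... and FPR across groups
    precision = yt[yp == 1].mean()   # calibration-style comparison of precision
    return selection_rate, tpr, fpr, precision

for g in (0, 1):
    sel, tpr, fpr, prec = group_rates(y_true, y_pred, group == g)
    print(f"group {g}: selection={sel:.2f}, TPR={tpr:.2f}, FPR={fpr:.2f}, precision={prec:.2f}")

# Disparate impact: ratio of selection rates (1.0 means parity)
print("disparate impact:", y_pred[group == 0].mean() / y_pred[group == 1].mean())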

Matching fairness terminology with definitions

Match the following types of formal fairness with their definitions. (A) Individual fairness, (B) Equalized odds, (C) Demographic parity, and (D) Group-level calibration

  1. The model is equally accurate across all demographic groups.
  2. Different demographic groups have the same true positive rates and false positive rates.
  3. Similar people are treated similarly.
  4. People from different demographic groups receive each outcome at the same rate.

A - 3, B - 2, C - 4, D - 1

But some types of unfairness cannot be directly measured by group-level statistical data. In particular, generative AI opens up new opportunities for bias and unfairness. Bias can occur through representational harms (e.g., creating content that over-represents one population subgroup at the expense of another), or through stereotypes (e.g., creating content that reinforces real-world stereotypes about a group of people). We’ll discuss some specific examples of bias in generative models next.

Fairness in generative AI


Generative models learn from statistical patterns in real-world data. These patterns reflect biases in that data: what data is available on the internet, what stereotypes it reinforces, and what forms of representation are missing.

Natural language

One set of social stereotypes that large AI models can learn is gender based: certain occupations are associated with men, and others with women. In the U.S., for example, doctors have historically and stereotypically been men.

In 2016, Caliskan et al. showed that machine translation systems exhibit gender bias, for instance by falling back on stereotypical gendered pronouns in ambiguous translations, such as when translating from Turkish – a language without gendered pronouns – to English.

In response, Google tweaked their translator algorithms to identify and correct for gender stereotypes in Turkish and several other widely-spoken languages. So when we repeat a similar experiment today, we get the following output:

Screenshot of Google Translate output. The English sentence "The doctor is on her lunch break" is translated to Turkish, and then the Turkish output is translated back to English as either "The doctor is on his lunch break" or "The doctor is on her lunch break".
Turkish Google Translate example (screenshot from 1/9/2024)

But for other, less widely-spoken languages, the original problem persists:

Screenshot of Google Translate output. The English sentence "The doctor is on her lunch break" is translated to Norwegian, and then the Norwegian output is translated back to English as "The doctor is on his lunch break".
Norwegian Google Translate example (screenshot from 1/9/2024)

We’re not trying to slander Google Translate here – the translation, without additional context, is ambiguous. And even if they extended the existing solution to Norwegian and other languages, the underlying problem (stereotypes in the training data) still exists. And with generative AI such as ChatGPT, the problem can be even more pernicious.

Red-teaming large language models

In cybersecurity, “red-teaming” is when well-intentioned people think like a hacker in order to make a system safer. In the context of Large Language Models (LLMs), red-teaming is used to try to get LLMs to output offensive, inaccurate, or unsafe content, with the goal of understanding the limitations of the LLM.

Try out red-teaming with ChatGPT or another LLM. Specifically, can you construct a prompt that causes the LLM to output stereotypes? Here are some example prompts, but feel free to get creative!

  • “Tell me a story about a doctor” (or other profession with gender)

  • If you speak a language other than English, how are ambiguous gendered pronouns handled? For instance, try the prompt “Translate ‘The doctor is here’ to Spanish”. Is a masculine or feminine form used for the doctor in Spanish?

If you use LLMs in your research, consider whether any of these issues are likely to be present for your use cases. If you do not use LLMs in your research, consider how these biases can affect downstream uses of the LLM’s output.

Most publicly-available LLM providers set up guardrails to avoid propagating biases present in their training data. For instance, as of the time of this writing (January 2024), the first suggested prompt, “Tell me a story about a doctor,” consistently creates a story about a woman doctor. Similarly, substituting other professions that have strong associations with men for “doctor” (e.g., “electrical engineer,” “garbage collector,” and “US President”) yield stories with female or gender-neutral names and pronouns.

Discussing other fairness issues

If you use LLMs in your research, consider whether any of these issues are likely to be present for your use cases. Share your thoughts in small groups with other workshop participants.

Image generation

The same problems that affect language models also affect image generation. Consider, for instance, PULSE, an algorithm developed by Menon et al. that can convert blurry images to higher resolution. Biases in its output were quickly unearthed and shared via social media.

Challenge

Who is shown in this blurred picture? blurry image of Barack Obama

While the picture is of Barack Obama, the upsampled image shows a white face. Unblurred version of the pixelated picture of Obama. Instead of showing Obama, it shows a white man.

You can try the model here.

Menon and colleagues subsequently updated their paper to discuss this issue of bias. They assert that the problems inherent in the PULSE model are largely a result of the underlying StyleGAN model, which they had used in their work.

Overall, it seems that sampling from StyleGAN yields white faces much more frequently than faces of people of color … This bias extends to any downstream application of StyleGAN, including the implementation of PULSE using StyleGAN.

Results indicate a racial bias among the generated pictures, with close to three-fourths (72.6%) of the pictures representing White people. Asian (13.8%) and Black (10.1%) are considerably less frequent, while Indians represent only a minor fraction of the pictures (3.4%).

These remarks get at a central issue: biases in any building block of a system (data, base models, etc.) get propagated forwards. In generative AI, such as text-to-image systems, this can result in representational harms, as documented by Bianchi et al. Fixing these issues of bias is still an active area of research. One important step is to be careful in data collection, and to try to assemble a balanced dataset that does not contain harmful stereotypes. But modern generative models are trained on massive datasets, so it is not possible to manually verify data quality. Instead, researchers use heuristic approaches to improve data quality, and then rely on various techniques to improve models’ fairness, which we discuss next.

Improving fairness of models


Model developers frequently try to improve the fairness of their models by intervening at one of three stages: pre-processing, in-processing, or post-processing. We’ll cover techniques within each of these paradigms in turn.

We start, though, by discussing why removing the sensitive attribute(s) is not sufficient. Consider the task of deciding which loan applicants are funded. Suppose we are concerned with racial bias in the model outputs. If we remove race from the set of attributes available to the model, the model cannot explicitly base its decisions on race. However, it could instead make decisions based on zip code, which in the US is a very good proxy for race.

Can we simply remove all proxy variables? We could likely remove zip code, if we cannot identify a causal relationship between where someone lives and whether they will be able to repay a loan. But what about an attribute like educational achievement? Someone with a college degree (compared with someone with, say, less than a high school degree) has better employment opportunities and therefore might reasonably be expected to be more likely to be able to repay a loan. However, educational attainment is still a proxy for race in the United States due to historical (and ongoing) discrimination.

Pre-processing generally modifies the dataset used for learning. Techniques in this category include:

  • Oversampling/undersampling: instead of training a machine learning model on all of the data, undersample the majority class by removing some of the majority class samples from the dataset in order to have a more balanced dataset. Alternatively, oversample the minority class by duplicating samples belonging to this group (a minimal sketch follows this list).

  • Data augmentation: the number of samples from minority groups may be increased by generating synthetic data with a generative adversarial network (GAN). We won’t cover this method in this workshop (using a GAN can be more computationally expensive than other techniques). If you’re interested, you can learn more about this method from the paper Inclusive GAN: Improving Data and Minority Coverage in Generative Models.

  • Changing feature representations: various techniques have been proposed to increase fairness by removing unfairness from the data directly. To do so, the data is converted into an alternate representation so that differences between demographic groups are minimized, yet enough information is maintained to learn a model that performs well. An advantage of this approach is that it is model-agnostic; a challenge is that it reduces the interpretability of interpretable models and makes post-hoc explainability less meaningful for black-box models.
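Here is the sketch referenced above: a minimal example of balancing group sizes with scikit-learn’s resample utility. The dataframe and column names are made up for illustration:

PYTHON

import pandas as pd
from sklearn.utils import resample

# Hypothetical dataset with an under-represented group
df = pd.DataFrame({
    "feature": range(100),
    "group": ["majority"] * 90 + ["minority"] * 10,
})

majority = df[df["group"] == "majority"]
minority = df[df["group"] == "minority"]

# Oversample: duplicate minority samples (with replacement) until the groups are the same size
minority_upsampled = resample(minority, replace=True, n_samples=len(majority), random_state=0)
balanced_over = pd.concat([majority, minority_upsampled])
print(balanced_over["group"].value_counts())

# Undersample: shrink the majority group to the size of the minority group instead
majority_downsampled = resample(majority, replace=False, n_samples=len(minority), random_state=0)
balanced_under = pd.concat([majority_downsampled, minority])
print(balanced_under["group"].value_counts())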

Pros and cons of preprocessing options

Discuss what you think the pros and cons of the different pre-processing options are. What techniques might work better in different settings?

A downside of oversampling is that it may violate statistical assumptions about independence of samples. A downside of undersampling is that the total amount of data is reduced, potentially resulting in models that perform less well overall.

A downside of using GANs to generate additional data is that this process may be expensive and require higher levels of ML expertise.

A challenge with all techniques is that if there is not sufficient data from minority groups, it may be hard to achieve good performance on the groups without simply collecting more or higher-quality data.

In-processing modifies the learning algorithm. Some specific in-processing techniques include:

  • Reweighting samples: many machine learning models allow for reweighting individual samples, i.e., indicating that misclassifying certain, rarer, samples should be penalized more severely in the loss function. In the code example, we show how to reweight samples using AIF360’s Reweighing class (a standalone sketch of the idea follows this list).

  • Incorporating fairness into the loss function: reweighting explicitly instructs the loss function to penalize the misclassification of certain samples more harshly. However, another option is to add a term to the loss function corresponding to the fairness metric of interest.
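As a standalone illustration of the reweighting idea (separate from the AIF360 Reweighing class used in the code later), the sketch below weights each sample by the ratio of the expected to the observed frequency of its (group, label) combination, so that group and label look statistically independent in the weighted data, and then passes the weights to scikit-learn via sample_weight. The data is synthetic:

PYTHON

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
group = rng.integers(0, 2, size=1000)                                # sensitive attribute
y = (X[:, 0] + 0.5 * group + rng.normal(size=1000) > 0).astype(int)  # outcome correlated with group

# Weight = expected frequency / observed frequency for each (group, label) cell
weights = np.empty(len(y), dtype=float)
for g in (0, 1):
    for label in (0, 1):
        cell = (group == g) & (y == label)
        expected = (group == g).mean() * (y == label).mean()
        weights[cell] = expected / cell.mean()

model = LogisticRegression().fit(X, y, sample_weight=weights)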

Post-processing modifies an existing model to increase its fairness. Techniques in this category often compute a custom threshold for each demographic group in order to satisfy a specific notion of group fairness. For instance, if a machine learning model for a binary prediction task uses 0.5 as a cutoff (e.g., raw scores less than 0.5 get a prediction of 0 and others get a prediction of 1), fair post-processing techniques may select different thresholds, e.g., 0.4 or 0.6 for different demographic groups.
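For instance, a minimal sketch of applying group-specific thresholds to raw scores might look like the following; the scores, group labels, and threshold values are made up for illustration:

PYTHON

import numpy as np

rng = np.random.default_rng(0)
scores = rng.uniform(size=10)          # raw model scores in [0, 1]
group = rng.integers(0, 2, size=10)    # demographic group of each sample

thresholds = {0: 0.4, 1: 0.6}          # hypothetical group-specific cutoffs

preds = np.array([int(s > thresholds[g]) for s, g in zip(scores, group)])
print(preds)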

In the code, we explore two different bias mitigation strategies, implemented in the AIF360 and Fairlearn toolkits.


PYTHON

import numpy as np
import pandas as pd

from IPython.display import Markdown, display

%matplotlib inline
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from aif360.metrics import BinaryLabelDatasetMetric
from aif360.metrics import ClassificationMetric
from aif360.explainers import MetricTextExplainer
from aif360.algorithms.preprocessing import Reweighing
from aif360.algorithms.preprocessing import OptimPreproc
from aif360.datasets import MEPSDataset19

from fairlearn.postprocessing import ThresholdOptimizer

from collections import defaultdict

This notebook is adapted from AIF360’s Medical Expenditure Tutorial.

The tutorial uses data from the Medical Expenditure Panel Survey. We include a short description of the data below. For more details, especially on the preprocessing, please see the AIF360 tutorial.

Scenario and data

The goal is to develop a healthcare utilization scoring model – i.e., to predict which patients will have the highest utilization of healthcare resources.

The original dataset contains information about various types of medical visits; the AIF360 preprocessing created a single output feature ‘UTILIZATION’ that combines utilization across all visit types. Then, this feature is binarized based on whether utilization is high, defined as >= 10 visits. Around 17% of the dataset has high utilization.

The sensitive feature (on which we will base fairness scores) is race. Other predictors include demographics, health assessment data, past diagnoses, and physical/mental limitations.

The data is divided into years (we follow the lead of AIF360’s tutorial and use 2015), and further divided into panels. We use Panel 19 (the first half of 2015).

Loading the data

First, the data needs to be moved into the correct location for the AIF360 library to find it. If you haven’t yet, run setup.sh to complete that step. (Then, restart the kernel and re-load the packages at the top of this file.)

We load the data, then create the train/validation/test splits and set up information about the privileged and unprivileged groups. (Recall, we focus on race as the sensitive feature.)

PYTHON

(dataset_orig_panel19_train,
 dataset_orig_panel19_val,
 dataset_orig_panel19_test) = MEPSDataset19().split([0.5, 0.8], shuffle=True)

sens_ind = 0
sens_attr = dataset_orig_panel19_train.protected_attribute_names[sens_ind]

unprivileged_groups = [{sens_attr: v} for v in
                       dataset_orig_panel19_train.unprivileged_protected_attributes[sens_ind]]
privileged_groups = [{sens_attr: v} for v in
                     dataset_orig_panel19_train.privileged_protected_attributes[sens_ind]]

Show details about the data.

PYTHON

def describe(train=None, val=None, test=None):
    if train is not None:
        display(Markdown("#### Training Dataset shape"))
        print(train.features.shape)
    if val is not None:
        display(Markdown("#### Validation Dataset shape"))
        print(val.features.shape)
    display(Markdown("#### Test Dataset shape"))
    print(test.features.shape)
    display(Markdown("#### Favorable and unfavorable labels"))
    print(test.favorable_label, test.unfavorable_label)
    display(Markdown("#### Protected attribute names"))
    print(test.protected_attribute_names)
    display(Markdown("#### Privileged and unprivileged protected attribute values"))
    print(test.privileged_protected_attributes, 
          test.unprivileged_protected_attributes)
    display(Markdown("#### Dataset feature names\n See [MEPS documentation](https://meps.ahrq.gov/data_stats/download_data/pufs/h181/h181doc.pdf) for details on the various features"))
    print(test.feature_names)

describe(dataset_orig_panel19_train, dataset_orig_panel19_val, dataset_orig_panel19_test)

Next, we will look at whether the dataset contains bias; i.e., does the outcome ‘UTILIZATION’ take on a positive value more frequently for one racial group than another?

The disparate impact score will be between 0 and 1, where 1 indicates no bias.

PYTHON

metric_orig_panel19_train = BinaryLabelDatasetMetric(
        dataset_orig_panel19_train,
        unprivileged_groups=unprivileged_groups,
        privileged_groups=privileged_groups)
explainer_orig_panel19_train = MetricTextExplainer(metric_orig_panel19_train)

print(explainer_orig_panel19_train.disparate_impact())

We see that the disparate impact is about 0.48, which means the privileged group has the favorable outcome at about 2x the rate as the unprivileged group does.

(In this case, the “favorable” outcome is label=1, i.e., high utilization.)

Train a model

We will train a logistic regression classifier.

PYTHON

dataset = dataset_orig_panel19_train
model = make_pipeline(StandardScaler(),
                      LogisticRegression(solver='liblinear', random_state=1))
fit_params = {'logisticregression__sample_weight': dataset.instance_weights}

lr_orig_panel19 = model.fit(dataset.features, dataset.labels.ravel(), **fit_params)

Validate the model

Recall that a logistic regression model can output probabilities (i.e., model.predict(dataset).scores) and we can determine our own threshold for predicting class 0 or 1.

The following function, test, computes the performance of the logistic regression model at a variety of thresholds, given by thresh_arr, an array of threshold values. We will continue to focus on disparate impact, but all other metrics are described in the AIF360 documentation.

PYTHON

def test(dataset, model, thresh_arr):
    try:
        # sklearn classifier
        y_val_pred_prob = model.predict_proba(dataset.features)
        pos_ind = np.where(model.classes_ == dataset.favorable_label)[0][0]
    except AttributeError as e:
        print(e)
        # aif360 inprocessing algorithm
        y_val_pred_prob = model.predict(dataset).scores
        pos_ind = 0
        
    metric_arrs = defaultdict(list)
    
    for thresh in thresh_arr:
        y_val_pred = (y_val_pred_prob[:, pos_ind] > thresh).astype(np.float64)

        dataset_pred = dataset.copy()
        dataset_pred.labels = y_val_pred
        metric = ClassificationMetric(
                dataset, dataset_pred,
                unprivileged_groups=unprivileged_groups,
                privileged_groups=privileged_groups)

        # various metrics - can look up what they are on your own
        metric_arrs['bal_acc'].append((metric.true_positive_rate()
                                     + metric.true_negative_rate()) / 2)
        metric_arrs['avg_odds_diff'].append(metric.average_odds_difference())
        metric_arrs['disp_imp'].append(metric.disparate_impact())
        metric_arrs['stat_par_diff'].append(metric.statistical_parity_difference())
        metric_arrs['eq_opp_diff'].append(metric.equal_opportunity_difference())
        metric_arrs['theil_ind'].append(metric.theil_index())
    
    return metric_arrs

PYTHON

thresh_arr = np.linspace(0.01, 0.5, 50)
val_metrics = test(dataset=dataset_orig_panel19_val,
                   model=lr_orig_panel19,
                   thresh_arr=thresh_arr)
lr_orig_best_ind = np.argmax(val_metrics['bal_acc'])

We will plot val_metrics. The x-axis will be the threshold we use to output the label 1 (i.e., if the raw score is larger than the threshold, we output 1).

The y-axis will show both balanced accuracy (in blue) and disparate impact (in red).

Note that we plot 1 - Disparate Impact, so now a score of 0 indicates no bias.

PYTHON

def plot(x, x_name, y_left, y_left_name, y_right, y_right_name):
    fig, ax1 = plt.subplots(figsize=(10,7))
    ax1.plot(x, y_left)
    ax1.set_xlabel(x_name, fontsize=16, fontweight='bold')
    ax1.set_ylabel(y_left_name, color='b', fontsize=16, fontweight='bold')
    ax1.xaxis.set_tick_params(labelsize=14)
    ax1.yaxis.set_tick_params(labelsize=14)
    ax1.set_ylim(0.5, 0.8)

    ax2 = ax1.twinx()
    ax2.plot(x, y_right, color='r')
    ax2.set_ylabel(y_right_name, color='r', fontsize=16, fontweight='bold')
    if 'DI' in y_right_name:
        ax2.set_ylim(0., 0.7)
    else:
        ax2.set_ylim(-0.25, 0.1)

    best_ind = np.argmax(y_left)
    ax2.axvline(np.array(x)[best_ind], color='k', linestyle=':')
    ax2.yaxis.set_tick_params(labelsize=14)
    ax2.grid(True)

PYTHON

disp_imp = np.array(val_metrics['disp_imp'])
disp_imp_err = 1 - disp_imp
plot(thresh_arr, 'Classification Thresholds',
     val_metrics['bal_acc'], 'Balanced Accuracy',
     disp_imp_err, '1 - DI')

If you like, you can plot other metrics, e.g., average odds difference.

In the next cell, we write a function to print out a variety of other metrics. Since we look at 1 - disparate impact, all of these metrics have a value of 0 if they are perfectly fair. Again, you can learn more details about the various metrics in the AIF360 documentation.

PYTHON

def describe_metrics(metrics, thresh_arr):
    best_ind = np.argmax(metrics['bal_acc'])
    print("Threshold corresponding to Best balanced accuracy: {:6.4f}".format(thresh_arr[best_ind]))
    print("Best balanced accuracy: {:6.4f}".format(metrics['bal_acc'][best_ind]))
    disp_imp_at_best_ind = 1 - metrics['disp_imp'][best_ind]
    print("\nCorresponding 1-DI value: {:6.4f}".format(disp_imp_at_best_ind))
    print("Corresponding average odds difference value: {:6.4f}".format(metrics['avg_odds_diff'][best_ind]))
    print("Corresponding statistical parity difference value: {:6.4f}".format(metrics['stat_par_diff'][best_ind]))
    print("Corresponding equal opportunity difference value: {:6.4f}".format(metrics['eq_opp_diff'][best_ind]))
    print("Corresponding Theil index value: {:6.4f}".format(metrics['theil_ind'][best_ind]))

describe_metrics(val_metrics, thresh_arr)

Test the model

Now that we have used the validation data to select the best threshold, we will evaluate the model on the test data.

PYTHON

lr_metrics = test(dataset=dataset_orig_panel19_test,
                       model=lr_orig_panel19,
                       thresh_arr=[thresh_arr[lr_orig_best_ind]])
describe_metrics(lr_metrics, [thresh_arr[lr_orig_best_ind]])

Mitigate bias with in-processing


We will use reweighting as an in-processing step to try to increase fairness. AIF360 provides a Reweighing class that we will use. If you’re interested, you can look at details about how it works in the documentation.

If you look at the documentation, you will see that AIF360 classifies reweighing as a pre-processing, not an in-processing, intervention. Technically, AIF360’s implementation modifies the dataset, not the learning algorithm, so it is pre-processing. But it is functionally equivalent to modifying the learning algorithm’s loss function, so we follow the convention of the fair ML field and call it in-processing.

PYTHON

# Reweighing is an AIF360 class to reweight the data
RW = Reweighing(unprivileged_groups=unprivileged_groups,
                privileged_groups=privileged_groups)
dataset_transf_panel19_train = RW.fit_transform(dataset_orig_panel19_train)

We’ll also define metrics for the reweighted data and print out the disparate impact of the dataset.

PYTHON

metric_transf_panel19_train = BinaryLabelDatasetMetric(
        dataset_transf_panel19_train,
        unprivileged_groups=unprivileged_groups,
        privileged_groups=privileged_groups)
explainer_transf_panel19_train = MetricTextExplainer(metric_transf_panel19_train)

print(explainer_transf_panel19_train.disparate_impact())

Then, we’ll train a model, validate it, and evaluate it on the test data.

PYTHON

# train
dataset = dataset_transf_panel19_train
model = make_pipeline(StandardScaler(),
                      LogisticRegression(solver='liblinear', random_state=1))
fit_params = {'logisticregression__sample_weight': dataset.instance_weights}
lr_transf_panel19 = model.fit(dataset.features, dataset.labels.ravel(), **fit_params)

PYTHON

# validate
thresh_arr = np.linspace(0.01, 0.5, 50)
val_metrics = test(dataset=dataset_orig_panel19_val,
                   model=lr_transf_panel19,
                   thresh_arr=thresh_arr)
lr_transf_best_ind = np.argmax(val_metrics['bal_acc'])

PYTHON

# plot validation results
disp_imp = np.array(val_metrics['disp_imp'])
disp_imp_err = 1 - np.minimum(disp_imp, 1/disp_imp)
plot(thresh_arr, 'Classification Thresholds',
     val_metrics['bal_acc'], 'Balanced Accuracy',
     disp_imp_err, '1 - min(DI, 1/DI)')

PYTHON

# describe validation results
describe_metrics(val_metrics, thresh_arr)

Test

PYTHON

lr_transf_metrics = test(dataset=dataset_orig_panel19_test,
                         model=lr_transf_panel19,
                         thresh_arr=[thresh_arr[lr_transf_best_ind]])
describe_metrics(lr_transf_metrics, [thresh_arr[lr_transf_best_ind]])

We see that the disparate impact score on the test data is better after reweighting than it was originally.

How do the other fairness metrics compare?

Mitigate bias with post-processing

We will use a method, ThresholdOptimizer, that is implemented in the Fairlearn library. ThresholdOptimizer finds custom thresholds for each demographic group so as to achieve parity in the desired group fairness metric.

We will focus on demographic parity, but feel free to try other metrics if you’re curious on how it does.

The first step is creating the ThresholdOptimizer object. We pass in the demographic parity constraint, and indicate that we would like to optimize the balanced accuracy score (other options include accuracy, and true or false positive rate – see the documentation for more details).

PYTHON

to = ThresholdOptimizer(estimator=model, constraints="demographic_parity", objective="balanced_accuracy_score", prefit=True)

Next, we fit the ThresholdOptimizer object to the validation data.

PYTHON

to.fit(dataset_orig_panel19_val.features, dataset_orig_panel19_val.labels, 
       sensitive_features=dataset_orig_panel19_val.protected_attributes[:,0])

Then, we’ll create a helper function, mini_test, to allow us to call the describe_metrics function even though we are no longer evaluating our method at a variety of thresholds.

After that, we call the ThresholdOptimizer’s predict function on the validation and test data, and then compute metrics and print the results.

PYTHON

def mini_test(dataset, preds):
    metric_arrs = defaultdict(list)
    dataset_pred = dataset.copy()
    dataset_pred.labels = preds
    metric = ClassificationMetric(
            dataset, dataset_pred,
            unprivileged_groups=unprivileged_groups,
            privileged_groups=privileged_groups)

    # various metrics - can look up what they are on your own
    metric_arrs['bal_acc'].append((metric.true_positive_rate()
                                    + metric.true_negative_rate()) / 2)
    metric_arrs['avg_odds_diff'].append(metric.average_odds_difference())
    metric_arrs['disp_imp'].append(metric.disparate_impact())
    metric_arrs['stat_par_diff'].append(metric.statistical_parity_difference())
    metric_arrs['eq_opp_diff'].append(metric.equal_opportunity_difference())
    metric_arrs['theil_ind'].append(metric.theil_index())
    
    return metric_arrs

PYTHON

to_val_preds = to.predict(dataset_orig_panel19_val.features, sensitive_features=dataset_orig_panel19_val.protected_attributes[:,0])
to_test_preds = to.predict(dataset_orig_panel19_test.features, sensitive_features=dataset_orig_panel19_test.protected_attributes[:,0])

PYTHON

to_val_metrics = mini_test(dataset_orig_panel19_val, to_val_preds)
to_test_metrics = mini_test(dataset_orig_panel19_test, to_test_preds)

PYTHON

print("Remember, `Threshold corresponding to Best balanced accuracy` is just a placeholder here.")
describe_metrics(to_val_metrics, [0])

PYTHON

print("Remember, `Threshold corresponding to Best balanced accuracy` is just a placeholder here.")

describe_metrics(to_test_metrics, [0])

Scroll up and see how these results compare with the original classifier and with the in-processing technique.

A major difference is that the accuracy is now lower. In practice, it might be better to use an algorithm that allows a custom tradeoff between the sacrifice in accuracy and the increase in fairness.

We can also see what threshold is being used for each demographic group by examining the interpolated_thresholder_.interpolation_dict property of the ThresholdOptimizer.

PYTHON

threshold_rules_by_group = to.interpolated_thresholder_.interpolation_dict
threshold_rules_by_group

Recall that a value of 1 in the Race column corresponds to White people, while a value of 0 corresponds to non-White people.

Due to the inherent randomness of the ThresholdOptimizer, you might get slightly different results than your neighbors. When we ran the previous cell, the output was

{0.0: {'p0': 0.9287205987170348, 'operation0': [>0.5], 'p1': 0.07127940128296517, 'operation1': [>-inf]}, 1.0: {'p0': 0.002549618320610717, 'operation0': [>inf], 'p1': 0.9974503816793893, 'operation1': [>0.5]}}

This tells us that for non-White individuals:

  • If the score is above 0.5, predict 1.

  • Otherwise, predict 1 with probability 0.071

And for White individuals:

  • If the score is above 0.5, predict 1 with probability 0.997.

  • Otherwise, predict 0.
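To see how such a randomized rule is applied, here is a minimal sketch that reproduces the decision logic implied by the dictionary above for a single score; the probabilities are copied from the printout, and your own run of the ThresholdOptimizer may produce different values:

PYTHON

import numpy as np

rng = np.random.default_rng(0)

def randomized_decision(score, race):
    # Probabilities and operations copied from the interpolation_dict printed above
    if race == 0.0:   # non-White group
        # with prob ~0.929 apply the >0.5 threshold; with prob ~0.071 always predict 1
        return int(score > 0.5) if rng.random() < 0.9287 else 1
    else:             # White group
        # with prob ~0.997 apply the >0.5 threshold; with prob ~0.003 always predict 0
        return int(score > 0.5) if rng.random() < 0.9975 else 0

print([randomized_decision(0.4, 0.0) for _ in range(10)])  # mostly 0, occasionally 1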

Discussion question: what are the pros and cons of improving the model fairness by introducing randomization?

Key Points

  • It’s important to consider many dimensions of model performance: a single accuracy score is not sufficient.
  • There is no single definition of “fair machine learning”: different notions of fairness are appropriate in different contexts.
  • Representational harms and stereotypes can be perpetuated by generative AI.
  • The fairness of a model can be improved by using techniques like data reweighting and model postprocessing.

Content from Interpretability versus explainability



Overview

Questions

  • What are popular machine learning models?
  • What are model interpretability and model explainability? Why are they important?
  • Which should you choose: interpretable models or explainable models?

Objectives

  • Showcase machine learning models that are widely used in practice.
  • Understand and distinguish between explainable machine learning models and interpretable machine learning models.
  • Describe two criteria to consider when deciding which type of model to choose.

Key Points

  • Model Explainability vs. Model Interpretability:
    • Interpretability: Refers to the degree to which a human can understand the cause of a decision made by a model. It is essential for verifying the correctness of the model, ensuring compliance with regulations, and enabling effective debugging.
    • Explainability: Refers to the extent to which the internal mechanics of a machine learning model can be explained in human terms. It is crucial for understanding how models make decisions, ensuring transparency, and building trust with stakeholders.
  • Choosing Between Explainable and Interpretable Models:
    • When Transparency is Critical: Opt for interpretable models (e.g., linear regression, decision trees) when it is essential to have a clear understanding of how decisions are made, such as in healthcare or finance.
    • When Performance is a Priority: Choose explainable models (e.g., neural networks, gradient boosting machines) when predictive accuracy is the primary concern, and you can use explanation methods to understand model behavior.

Exercise 1: Model Selection for Predicting COVID-19 Progression, a study by Giotta et al.

Objective:

To predict bad outcomes (death or transfer to an intensive care unit) from COVID-19 patients using hematological, biochemical, and inflammatory biomarkers.

Motivation:

In the early days of the COVID-19 pandemic, healthcare professionals around the world faced unprecedented challenges. Predicting the progression of the disease and identifying patients at high risk of severe outcomes became crucial for effective treatment and resource allocation. One such study, published on the National Center for Biotechnology Information (NCBI) website, investigated the characteristics of patients who either succumbed to the disease or required intensive care compared to those who recovered.

This study highlighted the critical role of various biomarkers, such as hematological, biochemical, and inflammatory markers, in understanding disease progression. However, simply identifying these markers was not enough. Clinicians needed tools that could not only predict outcomes with high accuracy but also provide clear, understandable reasons for their predictions.

Dataset Specification: Hematological biomarkers included white blood cells, neutrophils count, lymphocytes count, monocytes count, eosinophils count, platelet count, cluster of differentiation (CD)4, CD8 percentages, and hemoglobin. Biochemical markers were albumin, alanine aminotransferase, aspartate aminotransferase, total bilirubin, creatinine, creatinine kinase, lactate dehydrogenase (LDH), cardiac troponin I, myoglobin, and creatine kinase-MB. The coagulation markers were prothrombin time, activated partial thromboplastin time (APTT), and D-dimer. The inflammatory biomarkers were C-reactive protein (CRP), serum ferritin, procalcitonin (PCT), erythrocyte sedimentation rate, and interleukin and tumor necrosis factor-alpha (TNFα) levels.

Some statistics from the dataset:

Table 1: Main characteristics of the patients included in the study at baseline and results of comparison of percentage between outcome using chi-square or Fisher exact test.

|                            | Death or Transferred to ICU (n = 32): N | %      | Discharged Alive (n = 113): N |
|----------------------------|-----------------------------------------|--------|-------------------------------|
| **Sex**                    |                                         |        |                               |
| Male                       | 18                                      | 56.25% | 61                            |
| Female                     | 14                                      | 43.75% | 52                            |
| **Symptoms**               |                                         |        |                               |
| Dyspnea                    | 12                                      | 37.50% | 52                            |
| Cough                      | 5                                       | 15.63% | 35                            |
| Fatigue                    | 7                                       | 21.88% | 30                            |
| Headache                   | 2                                       | 6.25%  | 12                            |
| Confusion                  | 1                                       | 3.13%  | 9                             |
| Nausea                     | 1                                       | 3.13%  | 8                             |
| Sick                       | 1                                       | 3.13%  | 6                             |
| Pharyngitis                | 1                                       | 3.13%  | 6                             |
| Nasal congestion           | 1                                       | 3.13%  | 3                             |
| Arthralgia                 | 0                                       | 0.00%  | 3                             |
| Myalgia                    | 1                                       | 3.13%  | 2                             |
| Arrhythmia                 | 3                                       | 9.38%  | 12                            |
| **Comorbidity**            |                                         |        |                               |
| Hypertension               | 12                                      | 37.50% | 71                            |
| Cardiovascular disease     | 12                                      | 37.50% | 43                            |
| Diabetes                   | 11                                      | 34.38% | 35                            |
| Cerebrovascular disease    | 9                                       | 28.13% | 19                            |
| Chronic kidney disease     | 8                                       | 25.00% | 14                            |
| COPD                       | 5                                       | 15.63% | 14                            |
| Tumors                     | 5                                       | 15.63% | 11                            |
| Hepatitis B                | 0                                       | 0.00%  | 6                             |
| Immunopathological disease | 1                                       | 3.13%  | 5                             |
Table 2: Comparison of clinical characteristics and laboratory findings between patients who died or were transferred to ICU and those who were discharged alive.

|                                  | Death or Transferred to ICU (n = 32): Median | Q1    | Q3     |
|----------------------------------|----------------------------------------------|-------|--------|
| Age (years)                      | 78.0                                         | 67.0  | 85.75  |
| Temperature (°C)                 | 36.5                                         | 36.0  | 36.9   |
| Respiratory rate (rpm)           | 20.0                                         | 18.0  | 20.0   |
| Cardiac frequency (rpm)          | 79.0                                         | 70.0  | 90.0   |
| Systolic blood pressure (mmHg)   | 137.5                                        | 116.0 | 150.0  |
| Diastolic blood pressure (mmHg)  | 77.5                                         | 65.0  | 83.0   |
| Temperature at admission (°C)    | 36.0                                         | 35.7  | 36.4   |
| Percentage of O2 saturation      | 90.0                                         | 87.0  | 95.0   |
| FiO2 (%)                         | 100.0                                        | 96.0  | 100.0  |
| Neutrophil count (×10³/µL)       | 7.98                                         | 4.75  | 10.5   |
| Lymphocyte count (×10³/µL)       | 1.34                                         | 0.85  | 1.98   |
| Platelet count (×10³/µL)         | 202.00                                       | 147.5 | 272.25 |
| Hemoglobin level (g/dL)          | 12.7                                         | 11.8  | 14.5   |
| Procalcitonin levels (ng/mL)     | 0.11                                         | 0.07  | 0.27   |
| CRP (mg/dL)                      | 8.06                                         | 2.9   | 16.1   |
| LDH (mg/dL)                      | 307.0                                        | 258.5 | 386.0  |
| Albumin (mg/dL)                  | 27.0                                         | 24.5  | 32.5   |
| ALT (mg/dL)                      | 23.0                                         | 12.0  | 47.5   |
| AST (mg/dL)                      | 30.0                                         | 22.0  | 52.5   |
| ALP (mg/dL)                      | 70.0                                         | 53.5  | 88.0   |
| Direct bilirubin (mg/dL)         | 0.15                                         | 0.1   | 0.27   |
| Indirect bilirubin (mg/dL)       | 0.15                                         | 0.012 | 0.002  |
| Total bilirubin (mg/dL)          | 0.3                                          | 0.2   | 0.6    |
| Creatinine (mg/dL)               | 1.03                                         | 0.6   | 1.637  |
| CPK (mg/dL)                      | 79.0                                         | 47.0  | 194.0  |
| Sodium (mg/dL)                   | 140.0                                        | 137.0 | 142.5  |
| Potassium (mg/dL)                | 4.4                                          | 4.0   | 5.0    |
| INR                              | 1.1                                          | 1.0   | 1.2    |
| IL-6 (pg/mL)                     | 88.8                                         | 13.7  | 119.7  |
| IgM (AU/mL)                      | 3.4                                          | 0.0   | 8.1    |
| IgG (AU/mL)                      | 12.0                                         | 5.7   | 13.4   |
| Length of stay (days)            | 11.0                                         | 5.75  | 17.0   |

Real-World Impact:

During the pandemic, numerous studies and models were developed to aid in predicting COVID-19 outcomes. This study serves as an excellent example of how detailed patient data can inform model development. By designing a suitable machine learning model, researchers and healthcare providers can not only achieve high predictive accuracy but also ensure that their findings are actionable and trustworthy.

Discussion Questions:

  1. Compare the Advantages:
    • What are the advantages of using explainable models such as decision trees in predicting COVID-19 outcomes?
    • What are the advantages of using black box models such as neural networks in this scenario?
  2. Assess the Drawbacks:
    • What are the potential drawbacks of using explainable models like decision trees?
    • What are the potential drawbacks of using black box models in healthcare settings?
  3. Decision-Making Criteria:
    • In what situations might you prioritize an explainable model over a black box model, and why?
    • Are there scenarios where the higher accuracy of black box models justifies their use despite their lack of transparency?
  4. Practical Application:
    • Design a simple decision tree based on the provided biomarkers to predict bad outcomes.
    • Evaluate how the decision tree can aid healthcare providers in making informed decisions.

Solution:

  1. Compare the Advantages:
    • Explainable Models: Allow healthcare professionals to understand and trust the model’s decisions, providing clear insights into which biomarkers contribute most to predicting bad outcomes. This transparency is crucial in critical fields such as healthcare, where understanding the decision-making process can inform treatment plans and improve patient outcomes.
    • Black Box Models: Often provide higher predictive accuracy, which can be crucial for identifying patterns in complex datasets. They can capture non-linear relationships and interactions that simpler models might miss.
  2. Assess the Drawbacks:
    • Explainable Models: May not capture complex relationships in the data as effectively as black box models, potentially leading to lower predictive accuracy in some cases.
    • Black Box Models: Can be difficult to interpret, which hinders trust and adoption by medical professionals. Without understanding the model’s reasoning, it becomes challenging to validate its correctness, ensure regulatory compliance, and effectively debug or refine the model.
  3. Decision-Making Criteria:
    • Prioritizing Explainable Models: When transparency, trust, and regulatory compliance are critical, such as in healthcare settings where understanding and validating decisions is essential.
    • Using Black Box Models: When the need for high predictive accuracy outweighs the need for transparency, and when supplementary methods for interpreting the model’s output can be employed.
  4. Practical Application:
    • Design a Decision Tree: Using the given biomarkers, create a simple decision tree. Identify key split points (e.g., high CRP levels, elevated LDH) and illustrate how these markers can be used to predict bad outcomes. Tools like scikit-learn or any decision tree visualization tool can be used; a minimal scikit-learn sketch follows below.
    • Example Decision Tree: Here is a Decision Tree found by Giotta et al.
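
A minimal sketch of how such a decision tree could be built with scikit-learn is shown below. The data here is synthetic and purely illustrative (the biomarker columns and thresholds are assumptions, not values from Giotta et al.); with real patient data, the learned split points would come from the study's measurements.

PYTHON

# Illustrative decision tree for predicting a "bad outcome" (death or ICU transfer)
# from a few biomarkers. NOTE: the data below is synthetic, for demonstration only.
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)
n = 200

# Synthetic biomarker values, loosely inspired by the ranges in Table 2
X = pd.DataFrame({
    'CRP': rng.gamma(shape=2.0, scale=4.0, size=n),          # mg/dL
    'LDH': rng.normal(loc=300, scale=80, size=n),            # mg/dL
    'albumin': rng.normal(loc=30, scale=5, size=n),          # mg/dL
    'lymphocytes': rng.normal(loc=1.3, scale=0.5, size=n),   # 10^3/µL
})

# Synthetic label: higher CRP/LDH and lower albumin raise the chance of a bad outcome
risk = 0.05 * X['CRP'] + 0.01 * (X['LDH'] - 300) - 0.05 * (X['albumin'] - 30)
y = (risk + rng.normal(scale=0.5, size=n) > 0.5).astype(int)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Keep the tree shallow so that clinicians can read it end to end
tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

print(f'Test accuracy: {tree.score(X_test, y_test):.2f}')
print(export_text(tree, feature_names=list(X.columns)))  # human-readable if/then rules

The printed rules are exactly the kind of artifact an explainable model offers: each path from the root to a leaf is a plain-language criterion a clinician can inspect and challenge.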

Exercise 2: COVID-19 Diagnosis Using Chest X-Rays, a study by Ucar and Korkmaz

Objective: Diagnose COVID-19 through chest X-rays.

Motivation:

The COVID-19 pandemic has had an unprecedented impact on global health, affecting millions of people worldwide. One of the critical challenges in managing this pandemic is the rapid and accurate diagnosis of infected individuals. Traditional methods, such as the Reverse Transcription Polymerase Chain Reaction (RT-PCR) test, although widely used, have several drawbacks. These tests are time-consuming, require specialized equipment and personnel, and often suffer from low detection rates, necessitating multiple tests to confirm a diagnosis.

In this context, radiological imaging, particularly chest X-rays, has emerged as a valuable tool for COVID-19 diagnosis. Early studies have shown that COVID-19 causes specific abnormalities in chest X-rays, such as ground-glass opacities, which can be used as indicators of the disease. However, interpreting these images requires expertise and time, both of which are in short supply during a pandemic.

To address these challenges, researchers have turned to machine learning techniques…

Dataset Specification: Chest X-ray images

Real-World Impact:

The COVID-19 pandemic highlighted the urgent need for rapid and accurate diagnostic tools. Traditional methods like RT-PCR tests, while effective, are often time-consuming and have variable detection rates. Using chest X-rays for diagnosis offers a quicker and more accessible alternative. By analyzing chest X-rays, healthcare providers can swiftly identify COVID-19 cases, enabling timely treatment and isolation measures. Developing a machine learning method that can quickly and accurately analyze chest X-rays can significantly enhance the speed and efficiency of the healthcare response, especially in areas with limited access to RT-PCR testing.

Discussion Questions:

  1. Compare the Advantages:
    • What are the advantages of using deep neural networks in diagnosing COVID-19 from chest X-rays?
    • What are the advantages of traditional methods, such as genomic data analysis, for COVID-19 diagnosis?
  2. Assess the Drawbacks:
    • What are the potential drawbacks of using deep neural networks for COVID-19 diagnosis from chest X-rays?
    • How do these drawbacks compare to those of traditional methods?
  3. Decision-Making Criteria:
    • In what situations might you prioritize using deep neural networks over traditional methods, and why?
    • Are there scenarios where the rapid availability of X-ray results justifies the use of deep neural networks despite potential drawbacks?
  4. Practical Application:
    • Design a simple deep neural network architecture for diagnosing COVID-19 from chest X-rays.
    • Evaluate how this deep learning model can aid healthcare providers in making informed decisions quickly.

Solution:

  1. Compare the Advantages:
    • Deep Neural Networks: Provide high accuracy (e.g., 98%) in diagnosing COVID-19 from chest X-rays, offering a quick and non-invasive diagnostic tool. They can handle large amounts of image data and identify complex patterns that might be missed by human eyes.
    • Traditional Methods: Provide detailed and specific diagnostic information by analyzing genomic data and biomarkers, which can be crucial for understanding the virus’s behavior and patient response.
  2. Assess the Drawbacks:
    • Deep Neural Networks: Require large labeled datasets for training, which may not always be available. The models can be seen as “black boxes”, making it challenging to interpret their decisions without additional explainability methods.
    • Traditional Methods: Time-consuming and may have lower detection accuracy. They often require specialized equipment and personnel, leading to delays in diagnosis.
  3. Decision-Making Criteria:
    • Prioritizing Deep Neural Networks: When rapid diagnosis is critical, and chest X-rays are readily available. Useful in large-scale screening scenarios where speed is more critical than the detailed understanding provided by genomic data.
    • Using Traditional Methods: When detailed and specific information about the virus is needed for treatment planning, and when the availability of genomic data and biomarkers is not a bottleneck.
  4. Practical Application:
    • Design a Neural Network: Create a simple convolutional neural network (CNN) architecture using tools like TensorFlow or PyTorch. Use a dataset of labeled chest X-ray images to train and validate the model. A minimal PyTorch sketch is shown after this list.

    • Example Model: Here is a model proposed by Ucar and Korkmaz

      • Evaluate the Model: Train the model on your dataset and evaluate its performance. Discuss how this model can help healthcare providers make quick and accurate diagnoses.
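
Below is a minimal PyTorch sketch of a small CNN for this task. It is a generic illustration, not the SqueezeNet-based architecture that Ucar and Korkmaz actually propose, and it uses a random tensor in place of a real, preprocessed X-ray dataset.

PYTHON

# A small convolutional network for binary chest X-ray classification (COVID-19 vs. normal).
# Generic sketch only: dataset loading, preprocessing, and training are left out.
import torch
import torch.nn as nn

class SmallXRayCNN(nn.Module):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)       # (N, 64, 1, 1)
        x = torch.flatten(x, 1)    # (N, 64)
        return self.classifier(x)  # raw logits; train with CrossEntropyLoss

model = SmallXRayCNN()
dummy_batch = torch.randn(4, 1, 224, 224)  # stands in for 4 grayscale X-rays
print(model(dummy_batch).shape)            # torch.Size([4, 2])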

Content from Explainability methods overview


Last updated on 2024-07-10 | Edit this page

Overview

Questions

  • TODO

Objectives

  • TODO

Fantastic Explainability Methods and Where to Use Them


We will now take a bird’s-eye view of explainability methods that are widely applied to complex models like neural networks. We will get a sense of when to use which kind of method, and what the tradeoffs between these methods are.

Three axes of use cases for understanding model behavior


When deciding which explainability method to use, it is helpful to define your setting along three axes. This helps in understanding the context in which the model is being used, and the kind of insights you are looking to gain from the model.

Inherently Interpretable vs Post Hoc Explainable

Understanding the tradeoff between interpretability and complexity is crucial in machine learning. Simple models like decision trees, random forests, and linear regression offer transparency and ease of understanding, making them ideal for explaining predictions to stakeholders. In contrast, neural networks, while powerful, lack interpretability due to their complexity. Post hoc explainable techniques can be applied to neural networks to provide explanations for predictions, but it’s essential to recognize that using such methods involves a tradeoff between model complexity and interpretability.

Striking the right balance between these factors is key to selecting the most suitable model for a given task, considering both its predictive performance and the need for interpretability.

The tradeoff between interpretability and complexity (credits: AAAI 2021 Tutorial on Explaining Machine Learning Predictions: State of the Art, Challenges, Opportunities).

Local vs Global Explanations

Local explanations focus on describing model behavior within a specific neighborhood, providing insights into individual predictions. Conversely, global explanations aim to elucidate overall model behavior, offering a broader perspective. While global explanations may be more comprehensive, they run the risk of being overly complex.

Both types of explanations are valuable for uncovering biases and ensuring that the model makes predictions for the right reasons. The tradeoff between local and global explanations has a long history in statistics, with methods like linear regression (global) and kernel smoothing (local) illustrating the importance of considering both perspectives in statistical analysis.

Black box vs White Box Approaches

Techniques that require access to model internals (e.g., model architecture and model weights) are called “white box” while techniques that only need query access to the model are called “black box”. Even without access to the model weights, black box or top down approaches can shed a lot of light on model behavior. For example, by simply evaluating the model on certain kinds of data, high level biases or trends in the model’s decision making process can be unearthed.

White box approaches use the weights and activations of the model to understand its behavior. These classes of methods are more complex and diverse, and we will discuss them in more detail later in this episode. Some large models are closed-source due to commercial or safety concerns; for example, users can’t get access to the weights of GPT-4. This limits the use of white box explanations for such models.

Classes of Explainability Methods for Understanding Model Behavior


Diagnostic Testing

This is the simplest approach to explaining model behavior. It involves applying a series of unit tests to the model, where each test is a sample input for which you know what the correct output should be. By identifying test examples that break the heuristics the model relies on (called counterfactuals), you can gain insights into the high-level behavior of the model.

Example Methods: Counterfactuals, Unit tests

Pros and Cons: These methods allow for gaining insights into the high-level behavior of the model without needing access to model weights. This is especially useful with recent powerful closed-source models like GPT-4. One challenge with this approach is that it is hard to identify in advance which heuristics a model may depend on.
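
As a concrete illustration, here is a tiny black-box test harness for a sentiment classifier. The Hugging Face pipeline is used purely as a stand-in for whatever query access you have (it could equally be a call to a closed API), and the test pairs are hand-written counterfactuals where flipping a single word should flip the prediction; negation cases are included because they are a common failure mode.

PYTHON

# Black-box diagnostic testing: query the model on hand-written test cases
# with known expected labels. Swap in your own model or API call.
from transformers import pipeline

classifier = pipeline('sentiment-analysis')  # downloads a small default sentiment model

def predict_sentiment(text: str) -> str:
    return classifier(text)[0]['label']      # 'POSITIVE' or 'NEGATIVE'

# Each pair differs by a single word, so the predicted label should flip
counterfactual_tests = [
    ('The food was great.', 'POSITIVE'),
    ('The food was terrible.', 'NEGATIVE'),
    ('The service was not slow.', 'POSITIVE'),   # negation: models often get these wrong
    ('The service was not fast.', 'NEGATIVE'),
]

for text, expected in counterfactual_tests:
    predicted = predict_sentiment(text)
    status = 'PASS' if predicted == expected else 'FAIL'
    print(f'{status}: "{text}" -> {predicted} (expected {expected})')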

Baking interpretability into models

Some recent research has focused on tweaking highly complex models like neural networks to make them inherently more interpretable. One such example with language models involves training the model to generate rationales for its predictions, in addition to the predictions themselves. This approach has gained some traction, and there are even public benchmarks for evaluating the quality of these generated rationales.

Example methods: Rationales with WT5, Older approaches for rationales

Pros and cons: These models hope to achieve the best of both worlds: complex models that are also inherently interpretable. However, research in this direction is still new, and there are no established and reliable approaches for real world applications just yet.

Identifying Decision Rules of the Model:

In this class of methods, we try to find a set of rules that broadly explain the decision-making process of the model. Loosely, these rules would be of the form “if a specific condition is met, then the model will predict a certain class”.

Example methods: Anchors, Universal Adversarial Triggers

Table caption: "Generated anchors for Tabular datasets". Table shows the following rules: for the adult dataset, predict less than 50K if no capital gain or loss and never married. Predict over 50K if country is US, married, and work hours over 45. For RCDV dataset, predict not rearrested if person has no priors, no prison violations, and crime not against property. Predict re-arrested if person is male, black, has 1-5 priors, is not married, and the crime not against property. For the Lending dataset, predict bad loan if FICO score is less than 650. Predict good loan if FICO score is between 650 and 700 and loan amount is between 5400 and 10000.
Example use of anchors (table from Ribeiro et al.)

Pros and cons: Some global rules help find “bugs” in the model, or identify high level biases. But finding such broad coverage rules is challenging. Furthermore, these rules only showcase the model’s weaknesses, but give next to no insight as to why these weaknesses exist.
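
One simple, hands-on way to approximate global decision rules is to fit a shallow "surrogate" decision tree to the predictions of a more complex model and read off its rules. This is not the Anchors algorithm shown above, just an illustrative sketch using scikit-learn and a built-in dataset.

PYTHON

# Global surrogate sketch: train a shallow tree to imitate a black-box model,
# then inspect the tree's rules as an approximate description of the model.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# The "black box" whose behavior we want to summarize
black_box = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)
black_box_preds = black_box.predict(X)

# The surrogate is trained to imitate the black box's predictions, not the true labels
surrogate = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, black_box_preds)

fidelity = surrogate.score(X, black_box_preds)  # how faithfully the rules mimic the model
print(f'Surrogate fidelity to the black box: {fidelity:.2f}')
print(export_text(surrogate, feature_names=list(X.columns)))

As the pros and cons above suggest, a fidelity score well below 1.0 is a warning that no small set of broad-coverage rules captures the model's behavior.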

Visualizing model weights or representations

Just as a picture is worth a thousand words, visualizations can help encapsulate complex model behavior in a simple image. Visualizations are commonly used in explaining neural networks, where the weights or data representations of the model are directly visualized. Many such approaches involve reducing the high-dimensional weights or representations to a 2D or 3D space, using techniques like PCA, t-SNE, or UMAP. Alternatively, these visualizations can retain their high-dimensional representation, but use color or size to identify which dimensions or neurons are more important.

Example methods: Visualizing attention heatmaps, Weight visualizations, Model activation visualizations

Image shows a grid with 3 rows and 50 columns. Each cell is colored on a scale of -1.5 (white) to 0.9 (dark blue). Darker colors are concentrated in the first row in seemingly-random columns.
Example usage of visualizing attention heatmaps for a part-of-speech (POS) identification task using word2vec-encoded vectors. Each cell is a unit in a neural network (each row is a layer and each column is a dimension). Darker colors indicate that a unit is more important for predictive accuracy (figure from Li et al.).

Pros and cons: Gleaning model behavior from visualizations is intuitive and user-friendly, and visualizations sometimes have interactive interfaces. However, visualizations can be misleading, especially when high-dimensional vectors are reduced to 2D, leading to a loss of information (the crowding problem).

A well-known debate about the validity of visualizations has centered on attention heatmaps. Research has shown them to be unreliable, and then reliable again. (Check out the titles of these papers!) Thus, visualization should only be used as an additional step in an analysis, and not as a standalone method.
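
As a small sketch of this class of methods, the snippet below projects high-dimensional representations down to 2D with PCA and colors the points by class label. The `embeddings` array here is a random placeholder; in practice you would substitute hidden states extracted from your own model (the probing episode later in this lesson shows how to extract them).

PYTHON

# Project high-dimensional model representations to 2D for visual inspection.
# `embeddings` is a placeholder; replace it with real activations (n_examples x hidden_dim).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 768))   # placeholder for real model activations
labels = rng.integers(0, 2, size=200)      # placeholder class labels

points_2d = PCA(n_components=2).fit_transform(embeddings)

plt.scatter(points_2d[:, 0], points_2d[:, 1], c=labels, cmap='coolwarm', s=10)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('2D projection of model representations')
plt.show()

Swapping PCA for t-SNE or UMAP can change the picture considerably, which is exactly the kind of instability the attention-heatmap debate warns about.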

Understanding the impact of training examples

These techniques unearth which training data instances caused the model to generate a specific prediction for a given sample. At a high level, they estimate which training samples – if removed from the training process – would most change a particular prediction.

Example methods: Influence functions, Representer point selection

Two images. On the left, several antelope are standing in the background on a grassy field. On the right, several zebra graze in a field in the background, while there is one antelope in the foreground and other antelope in the background.
Example usage of representer point selection. The image on the left is a test image that is misclassified as a deer (the true label is antelope). The image on the right is the most influential training point. We see that this image is labeled “zebra,” but contains both zebras and antelopes. (example adapted from Yeh et al.)

Pros and cons: The insights from these approaches are actionable - identifying the data responsible for a prediction can help correct labels or annotation artifacts in that data. Unfortunately, these methods scale poorly with the size of the model and training data, quickly becoming computationally expensive. Furthermore, even when we know which datapoints had a high influence on a prediction, we don’t know what it was about those datapoints that caused the influence.
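
To make the idea concrete, here is a brute-force leave-one-out sketch: retrain a small model with one training point removed and measure how much the predicted probability for a chosen test example changes. Influence functions approximate this quantity without retraining; the naive version below is only feasible for tiny models and datasets.

PYTHON

# Brute-force "influence": drop one training point, retrain, and see how the
# prediction on a chosen test example shifts. Only practical for small problems.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
x_test = X[0:1]                                   # the prediction we want to explain

base_model = LogisticRegression(max_iter=1000).fit(X, y)
predicted_class = base_model.predict(x_test)[0]
base_prob = base_model.predict_proba(x_test)[0, predicted_class]

influence = []
for i in range(len(X)):
    X_loo = np.delete(X, i, axis=0)
    y_loo = np.delete(y, i)
    model_loo = LogisticRegression(max_iter=1000).fit(X_loo, y_loo)
    prob_loo = model_loo.predict_proba(x_test)[0, predicted_class]
    influence.append(base_prob - prob_loo)        # positive: the point supported the prediction

most_influential = int(np.argmax(np.abs(influence)))
print(f'Most influential training point: index {most_influential}, '
      f'influence {influence[most_influential]:+.4f}')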

Understanding the impact of a single example:

For a single input, what parts of the input were most important in generating the model’s prediction? These methods study the signal sent by various features to the model, and observe how the model reacts to changes in these features.

Example methods: Saliency Maps, LIME/SHAP, Perturbations (Input reduction, Adversarial Perturbations)

These methods can be further subdivided into two categories: gradient-based methods that rely on white-box model access to directly see the impact of changing a single input, and perturbation-based methods that manually perturb an input and re-query the model to see how the prediction changes.

Two rows images (5 images per row). Leftmost column shows two different pictures, each containing a cat and a dog. Remaining columns show the saliency maps using different techniques (VanillaGrad, InteGrad, GuidedBackProp, and SmoothGrad). Each saliency map has red dots (indicated regions that are influential for predicting "dog") and blue dots (influential for predicting "cat"). All methods except GuidedBackProp have good overlap between the respective dots and where the animals appear in the image. SmoothGrad has the most precise mapping.
Example saliency maps. The right 4 columns show the result of different saliency method techniques, where red dots indicate regions that are influential for predicting “dog” and blue dots indicate regions that are influential for predicting “cat”. The image creators argue that their method, SmoothGrad, is most effective at mapping model behavior to images. (Image taken from Smilkov et al.)

Pros and cons: These methods are fast to compute and flexible in their use across models. However, the insights gained from these methods are not directly actionable - knowing which part of the input caused the prediction does not explain why that part caused it. When an issue is found in the prediction process, it is also hard to tell whether it reflects an underlying problem with the model or only with the specific inputs tested. Relatedly, these methods can be unstable, and can even be fooled by adversarial examples.
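
A minimal perturbation-based example is occlusion: slide a blank patch over the image, re-query the model, and record how much the probability of the originally predicted class drops. The sketch below uses a pretrained ResNet-18 and a random tensor as a stand-in for a real, preprocessed image.

PYTHON

# Occlusion-based saliency sketch: regions whose masking hurts the prediction
# the most are treated as the most important parts of the input.
import torch
import torch.nn.functional as F
from torchvision.models import resnet18

model = resnet18(weights='IMAGENET1K_V1').eval()

image = torch.randn(1, 3, 224, 224)   # placeholder: use a real, normalized image here
patch, stride = 32, 32
heatmap = torch.zeros(224 // stride, 224 // stride)

with torch.no_grad():
    target_class = model(image).argmax(dim=1).item()
    base_prob = F.softmax(model(image), dim=1)[0, target_class]

    for i in range(heatmap.shape[0]):
        for j in range(heatmap.shape[1]):
            occluded = image.clone()
            occluded[:, :, i*stride:i*stride+patch, j*stride:j*stride+patch] = 0.0
            prob = F.softmax(model(occluded), dim=1)[0, target_class]
            heatmap[i, j] = base_prob - prob   # larger drop = more important region

print(heatmap)  # overlay this (upsampled) grid on the image to visualize saliency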

Probing internal representations

As the name suggests, this class of methods aims to probe the internals of a model, to discover what kind of information or knowledge is stored inside the model. Probes are often administered to a specific component of the model, like a set of neurons or layers within a neural network.

Example methods: Probing classifiers, Causal tracing

The phrase "The nurse examined the farmer for injuries because PRONOUN" is shown twice, once with PRONOUN=she and once with PRONOUN=he. Each word is annotated with the importance of three different attention heads. The distribution of which heads are important with each pronoun differs for all words, but especially for nurse and farmer.
Example probe output. The image shows the result from probing three attention heads. We see that gender stereotypes are encoded into the model because the heads that are important for nurse and farmer change depending on the final pronoun. Specifically, Head 5-10 attends to the stereotypical gender assignment while Head 4-6 attends to the anti-stereotypical gender assignment. (Image taken from Vig et al.)

Pros and cons: Probes have shown that it is possible to find highly interpretable components in a complex model; e.g., MLP layers in transformers have been shown to store factual knowledge in a structured manner. However, there is no systematic way of finding interpretable components, and many components may remain difficult for humans to understand. Furthermore, the model components that have been shown to contain certain knowledge may not actually play a role in the model’s prediction.

Is that all?

Nope! We’ve discussed a few of the common explanation techniques, but many others exist. In particular, specialized model architectures often need their own explanation algorithms. For instance, Yuan et al. give an overview of different explanation techniques for graph neural networks (GNNs).

Classifying explanation techniques

For each of the explanation techniques described above, discuss the following with a partner:

  • Does it require black-box or white-box model access?
  • Are the explanations it provides global or local?
  • Is the technique post-hoc or does it rely on inherent interpretability of the model?

| Approach | Post Hoc or Inherently Interpretable? | Local or Global? | White Box or Black Box? |
| --- | --- | --- | --- |
| Diagnostic Testing | Post Hoc | Global | Black Box |
| Baking interpretability into models | Inherently Interpretable | Local | White Box |
| Identifying Decision Rules of the Model | Post Hoc | Both | White Box |
| Visualizing model weights or representations | Post Hoc | Global | White Box |
| Understanding the impact of training examples | Post Hoc | Local | White Box |
| Understanding the impact of a single example | Post Hoc | Local | Both |
| Probing internal representations of a model | Post Hoc | Global/Local | White Box |

What explanation should you use when? There is no simple answer, as it depends upon your goals (i.e., why you need an explanation), who the audience is, the model architecture, and the availability of model internals (e.g., there is no white-box access to ChatGPT unless you work for OpenAI!). The next exercise asks you to consider different scenarios and discuss what explanation techniques are appropriate.

Challenge

Think about the following scenarios and suggest which explainability method would be most appropriate to use, and what information could be gained from that method. Furthermore, think about the limitations of your findings.

Note: These are open-ended questions, and there is no correct answer. Feel free to break into discussion groups to discuss the scenarios.

Scenario 1: Suppose that you are an ML engineer working at a tech company. A fast-food chain consults with you about sentiment analysis based on feedback they collected on Yelp and through their own survey. You use an open-source LLM such as Llama-2 and fine-tune it on the review text data. The fast-food company asks you to provide explanations of the model: Is there any outlier review? How does each review in the data affect the fine-tuned model? Which part of the language in a review indicates that a customer likes or dislikes the food? Can you score the food quality according to the reviews? Does the review data show a trend over time? Which items are gaining or losing popularity? Q: Can you suggest a few explainability methods that may be useful for answering these questions?

Scenario 2: Suppose that you are a radiologist who analyzes medical images of patients with the help of machine learning models. You use black-box models (e.g., CNNs, Vision Transformers) to complement human expertise and get useful information before making high-stakes decisions. Which areas of a medical image most likely explain the output of a black-box model? Can we visualize and understand what features are captured by the intermediate components of the black-box models? How do we know if there is a distribution shift? How can we tell if an image is an out-of-distribution example? Q: Can you suggest a few explainability methods that may be useful for answering these questions?

Scenario 3: Suppose that you work in genomics and you have just collected samples of single-cell data into a table: each row records gene expression levels, and each column represents a single cell. You are interested in scientific hypotheses about the evolution of cells. What exploratory data analysis techniques would you use to examine the dataset? How do you check whether there are potential outliers or irregularities in the dataset? You believe that only a few genes are playing a role in your study. What can you do to find the set of most explanatory genes? How do you know if there is clustering, and if there is a trajectory of changes in the cells? Q: Can you explain the decisions you make for each method you use?

Summary


There are many available explanation techniques, and they differ along three dimensions: model access (white-box or black-box), explanation scope (global or local), and approach (inherently interpretable or post hoc). There is often no objectively right answer as to which explanation technique to use in a given situation, as the different methods have different tradeoffs.

Content from Explainability methods: deep dive


Last updated on 2024-07-31 | Edit this page

Overview

Questions

  • TODO

Objectives

  • TODO

A Deep Dive into Methods for Understanding Model Behaviour


In the previous section, we scratched the surface of explainability methods, introducing you to the broad classes of methods designed to understand different aspects of a model’s behavior.

Now, we will dive deeper into two widely used methods, each of which answers one key question:

What part of my input causes this prediction?


When a model makes a prediction, we often want to know which parts of the input were most important in generating that prediction. This helps confirm if the model is making its predictions for the right reasons. Sometimes, models use features totally unrelated to the task for their prediction - these are known as ‘spurious correlations’. For example, a model might predict that a picture contains a dog because it was taken in a park, and not because there is actually a dog in the picture.

Saliency maps are among the simplest and most popular methods used for this purpose. We will be working with a more sophisticated variant of this method, known as GradCAM.

Method and Examples

A saliency map is a kind of visualization - a heatmap over the input that shows which parts of the input are most important in generating the model’s prediction. Saliency maps can be calculated using the gradients of a neural network, or by perturbing the input to any ML model and observing how the model reacts to these perturbations. The key intuition is that if a small change in a part of the input causes a large change in the model’s prediction, then that part of the input is important for the prediction. Gradients are useful here because they indicate how much the model’s prediction would change if the input were changed slightly.

For example, in an image classification task, a saliency map can be used to highlight the parts of the image that the model is focusing on to make its prediction. In a text classification task, a saliency map can be used to highlight the words or phrases that are most important for the model’s prediction.

GradCAM is an extension of this idea, which uses the gradients of the final layer of a convolutional neural network to generate a heatmap that highlights the important regions of an image. This heatmap can be overlaid on the original image to visualize which parts of the image are most important for the model’s prediction.

Other variants of this method include Integrated Gradients, SmoothGrad, and others, which are designed to provide more robust and reliable explanations for model predictions. However, GradCAM is a good starting point for understanding how saliency maps work, and is a popularly used approach.

Alternative approaches, which may not directly generate heatmaps, include LIME and SHAP, which are also popular and recommended for further reading.

Limitations and Extensions

Gradient-based saliency methods like GradCAM are fast to compute, requiring only a handful of backpropagation passes through the model to generate the heatmap. They can also be applied, without modification, to any model trained using gradient descent. Additionally, the results obtained from these methods are intuitive and easy to understand, making them useful for explaining model predictions to non-experts.

However, their use is limited to differentiable models for which white-box access is available. It is also difficult to apply these methods to tasks beyond classification, which limits their use with many recent generative models (think LLMs).

Another limitation is that the insights gained from these methods are not directly actionable - knowing which part of the input caused the prediction does not explain why that part caused it. When an issue is found in the prediction process, it is also hard to tell whether it reflects an underlying problem with the model or only with the specific inputs tested.

What part of my model causes this prediction?


When a model makes a correct prediction on a task it has been trained on (known as a ‘downstream task’), Probing classifiers can be used to identify if the model actually contains the relevant information or knowledge required to make that prediction, or if it is just making a lucky guess. Furthermore, probes can be used to identify the specific components of the model that contain this relevant information, providing crucial insights for developing better models over time.

Method and Examples

A neural network takes its input as a series of vectors, or representations, and transforms them through a series of layers to produce an output. The job of the main body of the neural network is to develop representations that are as useful for the downstream task as possible, so that the final few layers of the network can make a good prediction.

This essentially means that a good representation is one that already contains all the information required to make a good prediction. In other words, the features or representations from the model are easily separable by a simple classifier. That classifier is what we call a ‘probe’: a simple model that takes the representations from the model as input and tries to learn the downstream task from them. The probe itself is designed to be too simple to learn the task on its own. This means that the only way the probe can perform well on the task is if the representations it is given already contain the information needed to make the prediction.

These representations can be taken from any part of the model. Generally, using representations from the last layer of a neural network helps identify whether the model contains the information needed to make predictions for the downstream task at all. However, this can be extended further: probing the representations from different layers of the model can help identify where in the model the information is stored, and how it is transformed through the model.

Probes have been used frequently in NLP, where they have been applied to check whether language models contain certain kinds of linguistic information. These probes can be designed with varying levels of complexity. For example, simple probes have shown that language models contain information about simple syntactic features like part-of-speech tags, and more complex probes have shown that models encode entire parse trees of sentences.

Limitations and Extensions

One large challenge in using probes is identifying the correct architectural design of the probe. Too simple, and it may not be able to learn the downstream task at all. Too complex, and it may be able to learn the task even if the model does not contain the information required to make the prediction.

Another large limitation is that even if a probe is able to learn the downstream task, it does not mean that the model is actually using the information contained in the representations to make the prediction. So essentially, a probe can only tell us if a part of the model can make the prediction, not if it does make the prediction.

A new approach known as Causal Tracing addresses this limitation. The objective of this approach is similar to probes: attempting to understand which part of a model contains information relevant to a downstream task. The approach involves iterating through all parts of the model being examined (e.g. all layers of a model), and disrupting the information flow through that part of the model. (This could be as easy as adding some kind of noise on top of the weights of that model component). If the model performance on the downstream task suddenly drops on disrupting a specific model component, we know for sure that that component not only contains the information required to make the prediction, but that the model is actually using that information to make the prediction.
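
The sketch below captures the spirit of this idea on a toy scale: train a small classifier, then add noise to one weight matrix at a time and watch how accuracy changes. Real causal tracing work (for example, on large language models) is considerably more careful about what gets corrupted and restored, so treat this purely as an illustration.

PYTHON

# Causal-tracing-style sketch: corrupt one layer at a time and measure the damage.
# Layers whose corruption hurts accuracy the most are ones the model actually relies on.
import torch
import torch.nn as nn
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X, y = torch.tensor(X, dtype=torch.float32), torch.tensor(y)

model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),
                      nn.Linear(64, 64), nn.ReLU(),
                      nn.Linear(64, 2))

# Quick training loop (a real analysis would use a properly trained and validated model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(200):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(X), y)
    loss.backward()
    optimizer.step()

def accuracy() -> float:
    with torch.no_grad():
        return (model(X).argmax(dim=1) == y).float().mean().item()

clean_accuracy = accuracy()
for name, param in model.named_parameters():
    if 'weight' not in name:
        continue
    original = param.data.clone()
    param.data += 0.5 * torch.randn_like(param)   # disrupt this component
    print(f'{name}: accuracy {clean_accuracy:.2f} -> {accuracy():.2f}')
    param.data = original                         # restore before testing the next layer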

Challenge

Now, it’s time to try implementing these methods yourself! Pick one of the following problems to work on: the linear probe or GradCAM, each covered in one of the next two episodes.

It’s time to get your hands dirty now. Good luck, and have fun!

Content from Explainability methods: linear probe


Last updated on 2024-07-03 | Edit this page

Overview

Questions

  • TODO

Objectives

  • TODO

PYTHON

# Let's start by importing the necessary libraries.

import os
import torch
import logging
import numpy as np
from typing import Tuple
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from datasets import load_dataset, Dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig

logging.basicConfig(level=logging.INFO)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # This is needed to avoid a warning from huggingface

Now, let’s set the random seed to ensure reproducibility. Setting random seeds is like setting a starting point for your machine learning adventure. It ensures that every time you train your model, it starts from the same place, using the same random numbers, making your results consistent and comparable.

PYTHON

# Set random seeds for reproducibility - pick any number of your choice to set the seed. We use 42, since that is the answer to everything, after all.
torch.manual_seed(42)

PYTHON

# Set the GPU to use
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  
Loading the Dataset

Let’s load our data: the IMDB Movie Review dataset. The dataset contains text reviews and their corresponding sentiment labels (positive or negative). The label 1 corresponds to a positive review, and 0 corresponds to a negative review.

PYTHON

def load_imdb_dataset(keep_samples: int = 100) -> Tuple[Dataset, Dataset, Dataset]:
    '''
    Load the IMDB dataset from huggingface.
    The dataset contains text reviews and their corresponding sentiment labels (positive or negative).
    The label 1 corresponds to a positive review, and 0 corresponds to a negative review.
    :param keep_samples: Number of samples to keep, for faster training.
    :return: train, dev, test datasets. Each can be treated as a dictionary with keys 'text' and 'label'.
    '''
    dataset = load_dataset('imdb')

    # Keep only a subset of the data for faster training
    train_dataset = Dataset.from_dict(dataset['train'].shuffle(seed=42)[:keep_samples])
    dev_dataset = Dataset.from_dict(dataset['test'].shuffle(seed=42)[:keep_samples])
    test_dataset = Dataset.from_dict(dataset['test'].shuffle(seed=42)[keep_samples:2*keep_samples])

    # train_dataset[0] will return {'text': ...., 'label': 0}
    logging.info(f'Loaded IMDB dataset: {len(train_dataset)} training samples, {len(dev_dataset)} dev samples, {len(test_dataset)} test samples.')
    return train_dataset, dev_dataset, test_dataset

PYTHON

train_dataset, dev_dataset, test_dataset = load_imdb_dataset(keep_samples=50)
Loading the Model

We will load a model from huggingface, and use this model to get the embeddings for the probe. We use BERT for this example, but feel free to explore other models from huggingface after the exercise.

BERT is a transformer-based model, and is known to perform well on a variety of NLP tasks. The model is pre-trained on a large corpus of text, and can be fine-tuned for specific tasks.

PYTHON

def load_model(model_name: str) -> Tuple[AutoModel, AutoTokenizer]:
    '''
    Load a model from huggingface.
    :param model_name: Check huggingface for acceptable model names.
    :return: Model and tokenizer.
    '''
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, config=config)
    tokenizer.model_max_length = 128  # Truncate long reviews to 128 tokens (instead of the default 512) for computational efficiency

    logging.info(f'Loaded model and tokenizer: {model_name} with {model.config.num_hidden_layers} layers, '
                 f'hidden size {model.config.hidden_size} and maximum sequence length {tokenizer.model_max_length}.')
    return model, tokenizer

PYTHON

# To play around with other models, find a list of models and their model_ids at: https://huggingface.co/models
model, tokenizer = load_model('bert-base-uncased')

Let’s see what the model’s architecture looks like. How many layers does it have?

PYTHON

print(model)

Let’s see if your answer matches the actual number of layers in the model.

PYTHON

num_layers = model.config.num_hidden_layers
print(f'The model has {num_layers} layers.')
Setting up the Probe

Before we define the probing classifier, or probe, let’s set up some utility functions the probe will use. The probe will be trained on hidden representations from a specific layer of the BERT model. The get_embeddings_from_model function retrieves the intermediate layer representations (also known as embeddings) from a user-defined layer number.

The visualize_embeddings method can be used to see what these high dimensional hidden embeddings would look like when converted into a 2D view. The visualization is not intended to be informative in itself, and is only an additional tool used to get a sense of what the inputs to the probing classifier may look like.

PYTHON

def get_embeddings_from_model(model: AutoModel, tokenizer: AutoTokenizer, layer_num: int, data: list[str]) -> torch.Tensor:
    '''
    Get the embeddings from a model.
    :param model: The model to use. This is needed to get the embeddings.
    :param tokenizer: The tokenizer to use. This is needed to convert the data to input IDs.
    :param layer_num: The layer to get embeddings from. 0 is the input embeddings, and the last layer is the output embeddings.
    :param data: The data to get embeddings for. A list of strings.
    :return: The embeddings. Shape is N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
    '''
    logging.info(f'Getting embeddings from layer {layer_num} for {len(data)} samples...')

    # Batch the data for computational efficiency
    batch_size = 32
    batch_num = 1
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        logging.info(f'Getting embeddings for batch {batch_num}...')
        batch_num += 1

        # Tokenize the batch of data. Padding every batch to the same fixed length (the tokenizer's
        # model_max_length, set to 128 above) lets us concatenate embeddings across batches later.
        inputs = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True)

        # Get the embeddings from the model (no_grad: we only need the activations, not gradients)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Get the embeddings for the specified layer
        embeddings = outputs.hidden_states[layer_num]

        # Concatenate the embeddings from each batch
        if i == 0:
            all_embeddings = embeddings
        else:
            all_embeddings = torch.cat([all_embeddings, embeddings], dim=0)

    logging.info(f'Got embeddings for {len(data)} samples from layer {layer_num}. Shape: {all_embeddings.shape}')
    return all_embeddings

PYTHON

def visualize_embeddings(embeddings: torch.Tensor, labels: list, layer_num: int, save_plot: bool = False) -> None:
    '''
    Visualize the embeddings using t-SNE.
    :param embeddings: The embeddings to visualize. Shape is N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
    :param labels: The labels for the embeddings. A list of integers.
    :return: None
    '''

    # Since we are working with sentiment analysis, which is sentence based task, we can use sentence embeddings.
    # The sentence embeddings are simply the mean of the token embeddings of that sentence.
    sentence_embeddings = torch.mean(embeddings, dim=1)  # N, D

    # Convert to numpy
    sentence_embeddings = sentence_embeddings.detach().numpy()
    labels = np.array(labels)

    # Visualize the embeddings using t-SNE
    tsne = TSNE(n_components=2, random_state=0)
    embeddings_2d = tsne.fit_transform(sentence_embeddings)

    negative_points = embeddings_2d[labels == 0]
    positive_points = embeddings_2d[labels == 1]

    # Plot the embeddings. We want to colour the datapoints by label.
    fig, ax = plt.subplots()
    ax.scatter(negative_points[:, 0], negative_points[:, 1], label='Negative', color='red', marker='o', s=10, alpha=0.7)
    ax.scatter(positive_points[:, 0], positive_points[:, 1], label='Positive', color='blue', marker='o', s=10, alpha=0.7)
    plt.xlabel('t-SNE dimension 1')
    plt.ylabel('t-SNE dimension 2')
    plt.title(f't-SNE of Sentence Embeddings - Layer{layer_num}')
    plt.legend()

    # Save the plot if needed, then display it
    if save_plot:
        plt.savefig(f'tsne_layer_{layer_num}.png')
    plt.show()

    logging.info('Visualized embeddings using t-SNE.')

Now, it’s finally time to define our probe! We set this up as a class, where the probe itself is an object of this class. The class also contains methods used to train and evaluate the probe.

Read through this code block in a bit more detail - from this whole exercise, this part provides you with the most useful takeaways on ways to define and train neural networks!

PYTHON

class Probe():
    def __init__(self, hidden_dim: int = 768, class_size: int = 2) -> None:
        '''
        Initialize the probe.
        :param hidden_dim: The dimensionality of the hidden layer of the probe.
        :param class_size: The number of output classes (2 here: positive and negative sentiment).
        :return: None
        '''

        # The probe is a small feed-forward classifier with one hidden layer and an output layer.
        # The input to the probe is the embeddings from the model, and the output is the predicted class.

        # Exercise: Try playing around with the hidden_dim and the number of layers to see how it affects the probe's performance.
        # But watch out: if a complex probe performs well on the task, we don't know if the performance
        # is because of the model embeddings, or the probe itself learning the task!

        self.probe = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, class_size),

            # Add more layers here if needed

            # Note: no softmax/sigmoid layer is needed here. The CrossEntropyLoss used in train()
            # expects raw logits and applies the softmax internally.
        )


    def train(self, data_embeddings: torch.Tensor, labels: torch.Tensor, num_epochs: int = 10,
              learning_rate: float = 0.001, batch_size: int = 32) -> None:
        '''
        Train the probe on the embeddings of data from the model.
        :param data_embeddings: A tensor of shape N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
        :param labels: A tensor of shape N, where N is the number of samples. Each element is the label for the corresponding sample.
        :param num_epochs: The number of epochs to train the probe for. An epoch is one pass through the entire dataset.
        :param learning_rate: How fast the probe learns. A hyperparameter.
        :param batch_size: Used to batch the data for computational efficiency. A hyperparameter.
        :return:
        '''

        # Setup the loss function (training objective) for the training process.
        # The cross-entropy loss is used for multi-class classification, and represents the negative log likelihood of the true class.
        criterion = torch.nn.CrossEntropyLoss()

        # Setup the optimization algorithm to update the probe's parameters during training.
        # The Adam optimizer is an extension to stochastic gradient descent, and is a popular choice.
        optimizer = torch.optim.Adam(self.probe.parameters(), lr=learning_rate)

        # Train the probe
        logging.info('Training the probe...')
        for epoch in range(num_epochs):  # Pass over the data num_epochs times

            for i in range(0, len(data_embeddings), batch_size):

                # Iterate through one batch of data at a time
                batch_embeddings = data_embeddings[i:i+batch_size].detach()
                batch_labels = labels[i:i+batch_size]

                # Convert to sentence embeddings, since we are performing a sentence classification task
                batch_embeddings = torch.mean(batch_embeddings, dim=1)  # N, D

                # Get the probe's predictions, given the embeddings from the model
                outputs = self.probe(batch_embeddings)

                # Calculate the loss of the predictions, against the true labels
                loss = criterion(outputs, batch_labels)

                # Backward pass - update the probe's parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        logging.info('Trained the probe.')


    def predict(self, data_embeddings: torch.Tensor, batch_size: int = 32) -> torch.Tensor:
        '''
        Get the probe's predictions on the embeddings from the model, for unseen data.
        :param data_embeddings: A tensor of shape N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
        :param batch_size: Used to batch the data for computational efficiency.
        :return: A tensor of shape N, where N is the number of samples. Each element is the predicted class for the corresponding sample.
        '''

        # Iterate through batches
        for i in range(0, len(data_embeddings), batch_size):

            # Iterate through one batch of data at a time
            batch_embeddings = data_embeddings[i:i+batch_size]

            # Convert to sentence embeddings, since we are performing a sentence classification task
            batch_embeddings = torch.mean(batch_embeddings, dim=1)  # N, D

            # Get the probe's predictions (no_grad: we are not training here)
            with torch.no_grad():
                outputs = self.probe(batch_embeddings)

            # Get the predicted class for each sample
            _, predicted = torch.max(outputs, dim=-1)

            # Concatenate the predictions from each batch
            if i == 0:
                all_predicted = predicted
            else:
                all_predicted = torch.cat([all_predicted, predicted], dim=0)

        return all_predicted


    def evaluate(self, data_embeddings: torch.tensor, labels: torch.tensor, batch_size: int = 32) -> float:
        '''
        Evaluate the probe's performance by testing it on unseen data.
        :param data_embeddings: A tensor of shape N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
        :param labels: A tensor of shape N, where N is the number of samples. Each element is the label for the corresponding sample.
        :return: The accuracy of the probe on the unseen data.
        '''

        # Iterate through batches
        for i in range(0, len(data_embeddings), batch_size):

            # Iterate through one batch of data at a time
            batch_embeddings = data_embeddings[i:i+batch_size]
            batch_labels = labels[i:i+batch_size]

            # Convert to sentence embeddings, since we are performing a sentence classification task
            batch_embeddings = torch.mean(batch_embeddings, dim=1)  # N, D

            # Get the probe's predictions
            with torch.no_grad():
                outputs = self.probe(batch_embeddings)

            # Get the predicted class for each sample
            _, predicted = torch.max(outputs, dim=-1)

            # Concatenate the predictions from each batch
            if i == 0:
                all_predicted = predicted
                all_labels = batch_labels
            else:
                all_predicted = torch.cat([all_predicted, predicted], dim=0)
                all_labels = torch.cat([all_labels, batch_labels], dim=0)

        # Calculate the accuracy of the probe
        correct = (all_predicted == all_labels).sum().item()
        accuracy = correct / all_labels.shape[0]
        logging.info(f'Probe accuracy: {accuracy:.2f}')

        return accuracy

PYTHON

# Initialize the probing classifier (or probe)
probe = Probe()
Analysing the model using Probes

Time to start evaluating the model using our probing tool! Let’s see which layer has the most information about sentiment analysis on IMDB. For this, we will train the probe on embeddings from each layer of the model, and see which layer performs the best on the dev set.

PYTHON

layer_wise_accuracies = []
best_probe, best_layer, best_accuracy = None, -1, 0

for layer_num in range(num_layers + 1):  # hidden_states has num_layers + 1 entries; index 0 is the input embeddings
    logging.info(f'\n\nEvaluating representations of layer {layer_num}...')

    train_embeddings = get_embeddings_from_model(model, tokenizer, layer_num=layer_num, data=train_dataset['text'])
    dev_embeddings = get_embeddings_from_model(model, tokenizer, layer_num=layer_num, data=dev_dataset['text'])
    train_labels, dev_labels = torch.tensor(train_dataset['label'],  dtype=torch.long), torch.tensor(dev_dataset['label'],  dtype=torch.long)

    # Before training the probe, let's visualize the embeddings using t-SNE.
    # If the layer has information about sentiment analysis, would we see some structure in the embeddings?
    # Compare plots from layers where the probe does poorly, with ones where it does well. What do you notice?
    visualize_embeddings(embeddings=train_embeddings, labels=train_dataset['label'], layer_num=layer_num, save_plot=False)

    # Now, let's train the probe on the embeddings from the model.
    # Feel free to play around with the training hyperparameters, and see what works best for your probe.
    probe = Probe()
    probe.train(data_embeddings=train_embeddings, labels=train_labels,
                num_epochs=5, learning_rate=0.001, batch_size=32)

    # Let's see how well our probe does on a held out dev set
    accuracy = probe.evaluate(data_embeddings=dev_embeddings, labels=dev_labels)
    layer_wise_accuracies.append(accuracy)

    # Keep track of the best probe
    if accuracy > best_accuracy:
        best_probe, best_layer, best_accuracy = probe, layer_num, accuracy

PYTHON

# Seeing a list of accuracies can be hard to interpret. Let's plot the layer-wise accuracies to see which layer is best.
plt.plot(layer_wise_accuracies)
plt.xlabel('Layer')
plt.ylabel('Accuracy')
plt.title('Probe Accuracy by Layer')
plt.grid(alpha=0.3)
plt.show()

Which layer has the best accuracy? What does this tell us about the model?

Let’s go ahead and stress test this. Is the best layer able to predict sentiment for sentences outside the IMDB dataset?

For answering this question, you are the test set! Try to think of challenging sequences for which the model may not be able to predict sentiment.

PYTHON

test_sequences = ['Your sentence here', 'Here is another sentence']
embeddings = get_embeddings_from_model(model=model, tokenizer=tokenizer, layer_num=best_layer, data=test_sequences)
preds = best_probe.predict(data_embeddings=embeddings)
predictions = ['Positive' if pred == 1 else 'Negative' for pred in preds]
print(f'Predictions for test sequences: {predictions}')

Content from Explainability methods: GradCAM


Last updated on 2024-07-03 | Edit this page

Overview

Questions

  • TODO

Objectives

  • TODO

PYTHON

# Let's begin by installing the grad-cam package - this will significantly simplify our implementation
!pip install grad-cam

PYTHON

# Packages to download test images
import requests

# Packages to view and process images
import cv2
import numpy as np
from PIL import Image
from google.colab.patches import cv2_imshow

# Packages to load the model
import torch
from torchvision.models import resnet50

# GradCAM packages
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

PYTHON

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Load Model

We’ll load the ResNet-50 model from torchvision. This model is pre-trained on the ImageNet dataset, which contains 1.2 million images across 1000 classes. ResNet-50 is a popular convolutional neural network. You can learn more about it here: https://pytorch.org/hub/pytorch_vision_resnet/

PYTHON

model = resnet50(pretrained=True).to(device).eval()
Load Test Image

PYTHON

# Let's first take a look at the image, which we source from the GradCAM package

url = "https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/both.png"
Image.open(requests.get(url, stream=True).raw)

PYTHON

# Cute, isn't it? Do you prefer dogs or cats?

# We will need to convert the image into a tensor to feed it into the model.
# Let's create a function to do this for us.
def load_image(url):
    rgb_img = np.array(Image.open(requests.get(url, stream=True).raw))
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img).to(device)
    return input_tensor, rgb_img

PYTHON

input_tensor, rgb_image = load_image(url)

Grad-CAM Time!

PYTHON

# Let's start by selecting which layers of the model we want to use to generate the CAM.
# For that, we will need to inspect the model architecture.
# We can do that by simply printing the model object.
print(model)

Here we want to interpret what the model as a whole is doing (not what a specific layer is doing). That means that we want to use the embeddings of the last layer before the final classification layer. This is the layer that contains the information about the image encoded by the model as a whole.

Looking at the model, we can see that the last layer before the final classification layer is layer4.

PYTHON

target_layers = [model.layer4]

We also want to pick a label for the CAM - this is the class we want to visualize the activation for. Essentially, we want to see what the model is looking at when it is predicting a certain class.

Since ResNet was trained on the ImageNet dataset with 1000 classes, let’s get an indexed list of those classes. We can then pick the index of the class we want to visualize.

PYTHON

imagenet_categories_url = \
     "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
# The gist stores a Python dict literal mapping class index -> label, so we evaluate it into a dict
labels = eval(requests.get(imagenet_categories_url).text)
labels

Well, that’s a lot! To simplify things, we have already picked out the indices of a few interesting classes.

  • 157: Siberian Husky
  • 162: Beagle
  • 245: French Bulldog
  • 281: Tabby Cat
  • 285: Egyptian cat
  • 360: Otter
  • 537: Dog Sleigh
  • 799: Sliding Door
  • 918: Street Sign

PYTHON

# Specify the target class for visualization here. If you set this to None, the class with the highest score from the model will automatically be used.
visualized_class_id = 245

PYTHON

def viz_gradcam(model, target_layers, class_id):

  # If no class is specified, GradCAM explains the model's highest-scoring class
  if class_id is None:
    targets = None
  else:
    targets = [ClassifierOutputTarget(class_id)]

  cam_algorithm = GradCAM
  with cam_algorithm(model=model, target_layers=target_layers) as cam:
      # Compute the class activation map for the (globally defined) input_tensor
      grayscale_cam = cam(input_tensor=input_tensor,
                          targets=targets)

      # Take the CAM for the first (and only) image in the batch
      grayscale_cam = grayscale_cam[0, :]

      # Overlay the heatmap on the original image and convert to BGR for display
      cam_image = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
      cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

  cv2_imshow(cam_image)

Finally, we can start visualizing! Let’s begin by seeing what parts of the image the model looks at to make its most confident prediction.

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=None)

Interesting, it looks like the model totally ignores the cat and makes a prediction based on the dog. If we set the output class to “French Bulldog” (class_id=245), we see the same visualization - meaning that the model is indeed looking at the correct part of the image to make the correct prediction.

Let’s see what the heatmap looks like when we force the model to look at the cat.

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=281)

The model is indeed looking at the cat when asked to predict the class “Tabby Cat” (class_id=281)! But why is it still predicting the dog? Well, the model was trained on the ImageNet dataset, which contains a lot of images of dogs and cats. The model has learned that the dog is a better indicator of the class “Tabby Cat” than the cat itself.

Let’s see another example of this. The image contains not only a dog and a cat, but also other items in the background. Can the model correctly identify the sliding door?

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=799)

It can! However, it seems to also think of the shelf behind the dog as a door.

Let’s try an unrelated object now. Where in the image does the model see a street sign?

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=918)

Looks like our analysis has revealed a shortcoming of the model! It seems to perceive cats and street signs similarly.

Ideally, when the target class is some unrelated object, a good model will look at no significant part of the image. For example, the model does a good job with the class for Dog Sleigh.

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=537)

Explaining model predictions through visualization techniques like this can be subjective and prone to error. However, it still provides a degree of insight that a completely black-box model would not.

Spend some time playing around with different classes and seeing which part of the image the model looks at. Feel free to play around with other base images as well. Have fun!

Content from Estimating model uncertainty


Last updated on 2024-06-19 | Edit this page

Overview

Questions

  • TODO

Objectives

  • TODO

Key Points

  • TODO

Content from OOD detection: overview, output-based methods


Last updated on 2024-08-14 | Edit this page

Overview

Questions

  • What are out-of-distribution (OOD) data and why is detecting them important in machine learning models?
  • How do output-based methods like softmax and energy-based methods work for OOD detection?
  • What are the limitations of output-based OOD detection methods?

Objectives

  • Understand the concept of out-of-distribution data and its significance in building trustworthy machine learning models.
  • Learn about different output-based methods for OOD detection, including softmax and energy-based methods
  • Identify the strengths and limitations of output-based OOD detection techniques.

Introduction to Out-of-Distribution (OOD) Data

What is OOD data?

Out-of-distribution (OOD) data refers to data that significantly differs from the training data on which a machine learning model was built. The difference can arise from either:

  • Semantic shift: OOD sample is drawn from a class that was not present during training
  • Covariate shift: OOD sample is drawn from a different domain; the input feature distribution is drastically different from the training data

TODO: Add closed/open-world image similar to Sharon Li’s tutorial at 4:28: https://www.youtube.com/watch?v=hgLC9_9ZCJI

Why does OOD data matter?

Models trained on a specific distribution might make incorrect predictions on OOD data, leading to unreliable outputs. In critical applications (e.g., healthcare, autonomous driving), encountering OOD data without proper handling can have severe consequences.

Ex1: Tesla crashes into jet

In April 2022, a Tesla Model Y crashed into a $3.5 million private jet at an aviation trade show in Spokane, Washington, while operating on the “Smart Summon” feature. The feature allows Tesla vehicles to autonomously navigate parking lots to their owners, but in this case, it resulted in a significant mishap.

  • The Tesla was summoned by its owner using the Tesla app, which requires holding down a button to keep the car moving. The car continued to move forward even after making contact with the jet, pushing the expensive aircraft and causing notable damage.
  • The crash highlighted several issues with Tesla’s Smart Summon feature, particularly its object detection capabilities. The system failed to recognize and appropriately react to the presence of the jet, a problem that has been observed in other scenarios where the car’s sensors struggle with objects that are lifted off the ground or have unusual shapes.

Ex2: IBM Watson for Oncology

IBM Watson for Oncology faced several issues due to OOD data. The system was primarily trained on data from Memorial Sloan Kettering Cancer Center (MSK), which did not generalize well to other healthcare settings. This led to the following problems:

  1. Unsafe recommendations: Watson for Oncology provided treatment recommendations that were not safe or aligned with standard care guidelines in many cases outside of MSK. This happened because the training data was not representative of the diverse medical practices and patient populations in different regions.
  2. Bias in training data: The system’s recommendations were biased towards the practices at MSK, failing to account for different treatment protocols and patient needs elsewhere. This bias is a classic example of an OOD issue, where the model encounters data (patients and treatments) during deployment that significantly differ from its training data.

Ex3: Doctors using GPT3

Misdiagnosis and Inaccurate Medical Advice

In various studies and real-world applications, GPT-3 has been shown to generate inaccurate medical advice when faced with OOD data. This can be attributed to the fact that the training data, while extensive, does not cover all possible medical scenarios and nuances, leading to hallucinations or incorrect responses when encountering unfamiliar input.

A study published by researchers at Stanford found that GPT-3, even when using retrieval-augmented generation, provided unsupported medical advice in about 30% of its statements. For example, it suggested the use of a specific dosage for a defibrillator based on monophasic technology, while the cited source only discussed biphasic technology, which operates differently.

Fake Medical Literature References

Another critical OOD issue is the generation of fake or non-existent medical references by LLMs. When LLMs are prompted to provide citations for their responses, they sometimes generate references that sound plausible but do not actually exist. This can be particularly problematic in academic and medical contexts where accurate sourcing is crucial.

In evaluations of GPT-3’s ability to generate medical literature references, it was found that a significant portion of the references were either entirely fabricated or did not support the claims being made. This was especially true for complex medical inquiries that the model had not seen in its training data.

Detecting and Handling OOD Data

Given the problems posed by OOD data, a reliable model should identify such instances, and then:

  1. Reject them during inference
  2. Ideally, hand these OOD instances to a model trained on a distribution that better matches them (i.e., a distribution for which they are in-distribution).

The second step is much more complicated/involved since it requires matching OOD data to essentially an infinite number of possible classes. For the current scope of this workshop, we will focus on just the first step.

How can we determine whether a given instance is OOD or ID? Over the past several years, there have been a wide assortment of new methods developed to tackle this task. In this episode, we will cover a few of the most common approaches and discuss advantages/disadvantages of each.

Threshold-based methods

Threshold-based methods are one of the simplest and most intuitive approaches for detecting out-of-distribution (OOD) data. The central idea is to define a threshold on a certain score or confidence measure, beyond which the data point is considered out-of-distribution. Typically, these scores are derived from the model’s output probabilities or other statistical measures of uncertainty. There are two general classes of threshold-based methods: output-based and distance-based.

Output-based thresholds

Output-based Out-of-Distribution (OOD) detection refers to methods that determine whether a given input is out-of-distribution based on the output of a trained model. These methods typically analyze the model’s confidence scores, energy scores, or other output metrics to identify data points that are unlikely to belong to the distribution the model was trained on. The main approaches within output-based OOD detection include:

  • Softmax scores: The softmax output of a neural network represents the predicted probabilities for each class. A common threshold-based method involves setting a confidence threshold, and if the maximum softmax score of an instance falls below this threshold, it is flagged as OOD (a minimal sketch follows this list).
  • Energy: The energy-based method also uses the network’s output but measures the uncertainty in a more nuanced way by calculating an energy score. The energy score typically captures the confidence more robustly, especially in high-dimensional spaces, and can be considered a more general and reliable approach than just using softmax probabilities.
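
Before we implement this on real data below, here is a minimal sketch of softmax-score thresholding, assuming a hypothetical `logits` array of raw model outputs; in practice the threshold would be tuned on validation data, as we do later in this episode. (The energy score is sketched in the energy-based section further below.)

PYTHON

import numpy as np

# Minimal sketch (not the episode's full pipeline): flag inputs as OOD when the
# maximum softmax probability falls below a chosen threshold.
# `logits` is a hypothetical (n_samples, n_classes) array of raw model outputs.
def max_softmax_ood_flags(logits, threshold=0.9):
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerically stable softmax
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    msp = probs.max(axis=1)   # maximum softmax probability per sample
    return msp < threshold    # True = flagged as OOD

# A confident prediction passes; a near-uniform one is flagged as OOD
print(max_softmax_ood_flags([[4.0, 0.1], [0.2, 0.1]], threshold=0.9))  # [False  True]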

Distance-based thresholds

Distance-based methods calculate the distance of an instance from the distribution of training data features learned by the model. If the distance is beyond a certain threshold, the instance is considered OOD. Common distance-based approaches include:

  • Mahalanobis distance: This method calculates the Mahalanobis distance of a data point from the mean of the training data distribution. A high Mahalanobis distance indicates that the instance is likely OOD.
  • K-nearest neighbors (KNN): This method involves computing the distance to the k-nearest neighbors in the training data. If the average distance to these neighbors is high, the instance is considered OOD.

We will focus on output-based methods (softmax and energy) in this episode and then do a deep dive into distance-based methods in the next episode.

Example 1: Softmax scores

Softmax-based out-of-distribution (OOD) detection methods are a fundamental aspect of understanding how models differentiate between in-distribution and OOD data. Even though energy-based methods are becoming more popular, grasping softmax OOD detection methods provides essential scaffolding for learning more advanced techniques. Furthermore, softmax thresholding is still in use throughout ML literature, and learning more about this method will help you better assess results from others.

In this first example, we will train a simple logistic regression model to classify images as T-shirts or pants. We will then evaluate how our model reacts to data outside of these two classes (“semantic shift”).

PYTHON

# General settings used throughout this example
verbose = False
alpha = 0.2       # transparency used in some plots
max_iter = 10     # logistic regression iterations (increase for a more thorough run)
n_epochs = 10     # training epochs (increase for a more thorough run)

Prepare the ID (train and test) and OOD data

  • ID = T-shirts/Blouses, Pants
  • OOD = any other class. For illustrative purposes, we’ll focus on images of sandals as the OOD class.

PYTHON

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from keras.datasets import fashion_mnist

def prep_ID_OOD_datasests(ID_class_labels, OOD_class_labels):
    # Load Fashion MNIST dataset
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    
    # Prepare OOD data using the requested OOD class labels (e.g., Sandals = 5)
    ood_filter = np.isin(test_labels, OOD_class_labels)
    ood_data = test_images[ood_filter]
    ood_labels = test_labels[ood_filter]
    print(f'ood_data.shape={ood_data.shape}')
    
    # Filter data for the requested ID class labels (e.g., T-shirts = 0 and Trousers = 1)
    train_filter = np.isin(train_labels, ID_class_labels)
    test_filter = np.isin(test_labels, ID_class_labels)
    
    train_data = train_images[train_filter]
    train_labels = train_labels[train_filter]
    print(f'train_data.shape={train_data.shape}')
    
    test_data = test_images[test_filter]
    test_labels = test_labels[test_filter]
    print(f'test_data.shape={test_data.shape}')

    # Return the filtered labels as well; they are needed for plotting and evaluation later
    return ood_data, train_data, train_labels, test_data, test_labels


def plot_data_sample(train_data, ood_data):
    """
    Plots a sample of in-distribution and OOD data.

    Parameters:
    - train_data: np.array, array of in-distribution data images
    - ood_data: np.array, array of out-of-distribution data images

    Returns:
    - fig: matplotlib.figure.Figure, the figure object containing the plots
    """
    fig = plt.figure(figsize=(10, 4))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(train_data[i], cmap='gray')
        plt.title("In-Dist")
        plt.axis('off')
    for i in range(5):
        plt.subplot(2, 5, i + 6)
        plt.imshow(ood_data[i], cmap='gray')
        plt.title("OOD")
        plt.axis('off')
    
    return fig

PYTHON

ood_data, train_data, train_labels, test_data, test_labels = prep_ID_OOD_datasests([0,1], [5])
fig = plot_data_sample(train_data, ood_data)
fig.savefig('../images/OOD-detection_image-data-preview.png', dpi=300, bbox_inches='tight')
plt.show()
Preview of image dataset

Visualizing OOD and ID data

PCA

PCA visualization can provide insights into how well a model is separating ID and OOD data. If the OOD data overlaps significantly with ID data in the PCA space, it might indicate that the model could struggle to correctly identify OOD samples.

Focus on Linear Relationships: PCA is a linear dimensionality reduction technique. It assumes that the directions of maximum variance in the data can be captured by linear combinations of the original features. This can be a limitation when the data has complex, non-linear relationships, as PCA may not capture the true structure of the data. However, if you’re using a linear model (as we are here), PCA can be more appropriate for visualizing in-distribution (ID) and out-of-distribution (OOD) data because both PCA and linear models operate under linear assumptions. PCA will effectively capture the main variance in the data as seen by the linear model, making it easier to understand the decision boundaries and how OOD data deviates from the ID data within those boundaries.

PYTHON

# Flatten images for PCA and logistic regression
train_data_flat = train_data.reshape((train_data.shape[0], -1))
test_data_flat = test_data.reshape((test_data.shape[0], -1))
ood_data_flat = ood_data.reshape((ood_data.shape[0], -1))

print(f'train_data_flat.shape={train_data_flat.shape}')
print(f'test_data_flat.shape={test_data_flat.shape}')
print(f'ood_data_flat.shape={ood_data_flat.shape}')

PYTHON

# Perform PCA to visualize the first two principal components
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
train_data_pca = pca.fit_transform(train_data_flat)
test_data_pca = pca.transform(test_data_flat)
ood_data_pca = pca.transform(ood_data_flat)

# Plotting PCA components
plt.figure(figsize=(10, 6))
scatter1 = plt.scatter(train_data_pca[train_labels == 0, 0], train_data_pca[train_labels == 0, 1], c='blue', label='T-shirt/top (ID)', alpha=0.5)
scatter2 = plt.scatter(train_data_pca[train_labels == 1, 0], train_data_pca[train_labels == 1, 1], c='red', label='Pants (ID)', alpha=0.5)
scatter3 = plt.scatter(ood_data_pca[:, 0], ood_data_pca[:, 1], c='green', label='Sandals (OOD)', edgecolor='k')

# Create a single legend for all classes
plt.legend(handles=[scatter1, scatter2, scatter3], loc="upper right")
plt.xlabel('First Principal Component')
plt.ylabel('Second Principal Component')
plt.title('PCA of In-Distribution and OOD Data')
plt.savefig('../images/OOD-detection_PCA-image-dataset.png', dpi=300, bbox_inches='tight')
plt.show()

PCA visualization

From this plot, we see that sandals are more likely to be confused with T-shirts than pants. It may also be surprising that these data clouds overlap so much given their semantic differences. Why might this be?

  • Over-reliance on linear relationships: Part of this has to do with the fact that we’re only looking at linear relationships and treating each pixel as its own input feature, which is usually never a great idea when working with image data. In our next example, we’ll switch to the more modern approach of CNNs.
  • Semantic gap != feature gap: Another factor of note is that images that have a wide semantic gap may not necessarily translate to a wide gap in terms of the data’s visual features (e.g., ankle boots and bags might both be small, have leather, and have zippers). Part of an effective OOD detection scheme involves thinking carefully about what sorts of data contaminations may be observed by the model, and assessing how similar these contaminations may be to your desired class labels.

Train and evaluate model on ID data

PYTHON

# Train a logistic regression classifier
model = LogisticRegression(max_iter=max_iter, solver='lbfgs', multi_class='multinomial').fit(train_data_flat, train_labels)

Before we worry about the impact of OOD data, let’s first verify that we have a reasonably accurate model for the ID data.

PYTHON

# Evaluate the model on in-distribution data
in_dist_preds = model.predict(test_data_flat)
in_dist_accuracy = accuracy_score(test_labels, in_dist_preds)
print(f'In-Distribution Accuracy: {in_dist_accuracy:.2f}')

PYTHON

from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

# Generate and display confusion matrix
cm = confusion_matrix(test_labels, in_dist_preds, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['T-shirt/top', 'Pants'])
disp.plot(cmap=plt.cm.Blues)
plt.savefig('../images/OOD-detection_ID-confusion-matrix.png', dpi=300, bbox_inches='tight')
plt.show()
ID confusion matrix

How does our model view OOD data?

A basic question we can start with is to ask: on average, how are OOD samples classified? Are they more likely to be T-shirts or pants? For this kind of question, we can calculate the probability scores for the OOD data and compare them to the ID data.

PYTHON

# Predict probabilities using the model on OOD data (Sandals)
ood_probs = model.predict_proba(ood_data_flat)
avg_ood_prob = np.mean(ood_probs, 0)
print(f"Avg. probability of sandal being T-shirt: {avg_ood_prob[0]:.4f}")
print(f"Avg. probability of sandal being pants: {avg_ood_prob[1]:.4f}")

id_probs = model.predict_proba(train_data_flat)
id_probs_shirts = id_probs[train_labels==0,:]
id_probs_pants = id_probs[train_labels==1,:]
avg_tshirt_prob = np.mean(id_probs_shirts, 0)
avg_pants_prob = np.mean(id_probs_pants, 0)

print()
print(f"Avg. probability of T-shirt being T-shirt: {avg_tshirt_prob[0]:.4f}")
print(f"Avg. probability of pants being pants: {avg_pants_prob[1]:.4f}")

Based on the difference in averages here, it looks like softmax may provide at least a somewhat useful signal in separating ID and OOD data. Let’s take a closer look by plotting histograms of all probability scores across our classes of interest (ID-Tshirt, ID-Pants, and OOD).

PYTHON

# Creating the figure and subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=False)
bins=60
# Plotting the histogram of probabilities for OOD data (Sandals)
axes[0].hist(ood_probs[:, 0], bins=bins, alpha=0.5, label='T-shirt probability')
axes[0].set_xlabel('Probability')
axes[0].set_ylabel('Frequency')
axes[0].set_title('OOD Data (Sandals)')
axes[0].legend()

# Plotting the histogram of probabilities for ID data (T-shirt)
axes[1].hist(id_probs_shirts[:, 0], bins=bins, alpha=0.5, label='T-shirt probability', color='orange')
axes[1].set_xlabel('Probability')
axes[1].set_title('ID Data (T-shirt/top)')
axes[1].legend()

# Plotting the histogram of probabilities for ID data (Pants)
axes[2].hist(id_probs_pants[:, 1], bins=bins, alpha=0.5, label='Pants probability', color='green')
axes[2].set_xlabel('Probability')
axes[2].set_title('ID Data (Pants)')
axes[2].legend()

# Adjusting layout
plt.tight_layout()
plt.savefig('../images/OOD-detection_histograms.png', dpi=300, bbox_inches='tight')
# Displaying the plot
plt.show()

Histograms of ID and OOD data

Alternatively, for a better comparison across all three classes, we can use a probability density plot. This allows for an easier comparison when the counts across classes lie on vastly different scales (e.g., a max of 35 vs. a max of 5000).

PYTHON

from scipy.stats import gaussian_kde

# Create figure
plt.figure(figsize=(10, 6))

# Transparency for the overlapping density curves
alpha = 0.4

# Plot PDF for ID T-shirt (T-shirt probability)
density_id_shirts = gaussian_kde(id_probs_shirts[:, 0])
x_id_shirts = np.linspace(0, 1, 1000)
plt.plot(x_id_shirts, density_id_shirts(x_id_shirts), label='ID T-shirt (T-shirt probability)', color='orange', alpha=alpha)

# Plot PDF for ID Pants (Pants probability)
density_id_pants = gaussian_kde(id_probs_pants[:, 0])
x_id_pants = np.linspace(0, 1, 1000)
plt.plot(x_id_pants, density_id_pants(x_id_pants), label='ID Pants (T-shirt probability)', color='green', alpha=alpha)

# Plot PDF for OOD (T-shirt probability)
density_ood = gaussian_kde(ood_probs[:, 0])
x_ood = np.linspace(0, 1, 1000)
plt.plot(x_ood, density_ood(x_ood), label='OOD (T-shirt probability)', color='blue', alpha=alpha)

# Adding labels and title
plt.xlabel('Probability')
plt.ylabel('Density')
plt.title('Probability Density Distributions for OOD and ID Data')
plt.legend()

plt.savefig('../images/OOD-detection_PSDs.png', dpi=300, bbox_inches='tight')

# Displaying the plot
plt.show()

Probability densities

Unfortunately, we observe a significant amount of overlap between OOD data and high T-shirt probability. Furthermore, the blue line doesn’t seem to decrease much as you move from 0.9 to 1, suggesting that even a very high threshold is likely to lead to OOD contamination (while also tossing out a significant portion of ID data).

For pants, the problem is much less severe. It looks like a low threshold (on this T-shirt probability scale) can separate nearly all OOD samples from being pants.

Setting a threshold

Let’s put our observations to the test and produce a confusion matrix that includes ID-pants, ID-Tshirts, and OOD class labels. We’ll start with a high threshold of 0.9 to see how that performs.

PYTHON

def softmax_thresh_classifications(probs, threshold):
    classifications = np.where(probs[:, 1] >= threshold, 1,  # classified as pants
                               np.where(probs[:, 0] >= threshold, 0,  # classified as shirts
                                        -1))  # classified as OOD
    return classifications

PYTHON

from sklearn.metrics import precision_recall_fscore_support

# Assuming ood_probs, id_probs, and train_labels are defined
# Threshold values
upper_threshold = 0.9

# Classifying OOD examples (sandals)
ood_classifications = softmax_thresh_classifications(ood_probs, upper_threshold)

# Classifying ID examples (T-shirts and pants)
id_classifications = softmax_thresh_classifications(id_probs, upper_threshold)

# Combine OOD and ID classifications and true labels
all_predictions = np.concatenate([ood_classifications, id_classifications])
all_true_labels = np.concatenate([-1 * np.ones(ood_classifications.shape), train_labels])

# Confusion matrix
cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

# Plotting the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix for OOD and ID Classification')

plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix1.png', dpi=300, bbox_inches='tight')

plt.show()

# Looking at F1, precision, and recall
precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[0, 1], average='macro')  # macro-averaging treats all classes equally (see the callout on averaging schemes below)

print(f"F1: {f1}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")

Even with a high threshold of 0.9, we end up with nearly a couple hundred OOD samples classified as ID. In addition, over 800 ID samples had to be tossed out due to uncertainty.

Quick exercise

What threshold is required to ensure that no OOD samples are incorrectly considered as ID? What percentage of ID samples are mistaken as OOD at this threshold? Answer: 0.9999, (3826+2414)/(3826+2414+2174+3586)=52%

With a very conservative threshold, we can make sure very few OOD samples are incorrectly classified as ID. However, the flip side is that conservative thresholds tend to incorrectly classify many ID samples as being OOD. In this case, we incorrectly assume almost 20% of shirts are OOD samples.

Iterative Threshold Determination

In practice, selecting an appropriate threshold is an iterative process that balances the trade-off between correctly identifying in-distribution (ID) data and accurately flagging out-of-distribution (OOD) data. Here’s how you can iteratively determine the threshold:

  • Define Evaluation Metrics: While confusion matrices are an excellent tool when you’re ready to more closely examine the data, we need a single metric that can summarize threshold performance so we can easily compare across thresholds. Common metrics include accuracy, precision, recall, or the F1 score for both ID and OOD detection.

  • Evaluate Over a Range of Thresholds: Test different threshold values and evaluate the performance on a validation set containing both ID and OOD data.

  • Select the Optimal Threshold: Choose the threshold that provides the best balance according to your chosen metrics.

Use the below code to determine what threshold should be set to ensure precision = 100%. What threshold is required for recall to be 100%? What threshold gives the highest F1 score?

Callout on averaging schemes

F1 scores can be calculated per class, and then averaged in different ways (macro, micro, or weighted) when dealing with multiclass or multilabel classification problems. Here are the key types of averaging methods (a short example follows this list):

  • Macro-Averaging: Calculates the F1 score for each class independently and then takes the average of these scores. This treats all classes equally, regardless of their support (number of true instances for each class).

  • Micro-Averaging: Aggregates the contributions of all classes to compute the average F1 score. This is typically used for imbalanced datasets as it gives more weight to classes with more instances.

  • Weighted-Averaging: Calculates the F1 score for each class independently and then takes the average, weighted by the number of true instances for each class. This accounts for class imbalance by giving more weight to classes with more instances.
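
As a quick illustration of how these averaging schemes differ, the short example below computes each variant with scikit-learn on a small, imbalanced toy problem (the labels are made up purely for illustration).

PYTHON

from sklearn.metrics import f1_score

# Toy 3-class example with imbalanced classes (labels are made up for illustration)
y_true = [0, 0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 0, 1, 1, 0, 2]

# 'macro' treats all classes equally; 'micro' pools all predictions together;
# 'weighted' weights each class's F1 by its number of true instances
for avg in ['macro', 'micro', 'weighted']:
    print(avg, f1_score(y_true, y_pred, average=avg))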

Callout on including OOD data in F1 calculation

PYTHON

# from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def eval_softmax_thresholds(thresholds, ood_probs, id_probs):
    # Store evaluation metrics for each threshold
    precisions = []
    recalls = []
    f1_scores = []
    
    for threshold in thresholds:
        # Classifying OOD examples (sandals)
        ood_classifications = softmax_thresh_classifications(ood_probs, threshold)
        
        # Classifying ID examples (T-shirts and pants)
        id_classifications = softmax_thresh_classifications(id_probs, threshold)
        
        # Combine OOD and ID classifications and true labels
        all_predictions = np.concatenate([ood_classifications, id_classifications])
        all_true_labels = np.concatenate([-1 * np.ones(ood_classifications.shape), train_labels])
        
        # Evaluate metrics
        precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[0, 1], average='macro')  # macro-averaging treats all classes equally
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        
    return precisions, recalls, f1_scores

PYTHON

# Define thresholds to evaluate
thresholds = np.linspace(.5, 1, 50)

# Evaluate on all thresholds
precisions, recalls, f1_scores = eval_softmax_thresholds(thresholds, ood_probs, id_probs)

PYTHON

def plot_metrics_vs_thresholds(thresholds, f1_scores, precisions, recalls, OOD_signal):
    # Find the best thresholds for each metric
    best_f1_index = np.argmax(f1_scores)
    best_f1_threshold = thresholds[best_f1_index]
    
    best_precision_index = np.argmax(precisions)
    best_precision_threshold = thresholds[best_precision_index]
    
    best_recall_index = np.argmax(recalls)
    best_recall_threshold = thresholds[best_recall_index]
    
    print(f"Best F1 threshold: {best_f1_threshold}, F1 Score: {f1_scores[best_f1_index]}")
    print(f"Best Precision threshold: {best_precision_threshold}, Precision: {precisions[best_precision_index]}")
    print(f"Best Recall threshold: {best_recall_threshold}, Recall: {recalls[best_recall_index]}")

    # Create a new figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot metrics as functions of the threshold
    ax.plot(thresholds, precisions, label='Precision', color='g')
    ax.plot(thresholds, recalls, label='Recall', color='b')
    ax.plot(thresholds, f1_scores, label='F1 Score', color='r')
    
    # Add best threshold indicators
    ax.axvline(x=best_f1_threshold, color='r', linestyle='--', label=f'Best F1 Threshold: {best_f1_threshold:.2f}')
    ax.axvline(x=best_precision_threshold, color='g', linestyle='--', label=f'Best Precision Threshold: {best_precision_threshold:.2f}')
    ax.axvline(x=best_recall_threshold, color='b', linestyle='--', label=f'Best Recall Threshold: {best_recall_threshold:.2f}')
    ax.set_xlabel(f'{OOD_signal} Threshold')
    ax.set_ylabel('Metric Value')
    ax.set_title('Evaluation Metrics as Functions of Threshold')
    ax.legend()

    return fig, best_f1_threshold, best_precision_threshold, best_recall_threshold

PYTHON

fig, best_f1_threshold, best_precision_threshold, best_recall_threshold = plot_metrics_vs_thresholds(thresholds, f1_scores, precisions, recalls, 'Softmax')
fig.savefig('../images/OOD-detection_metrics_vs_softmax-thresholds.png', dpi=300, bbox_inches='tight')
Evaluation metrics vs. softmax threshold

PYTHON

# Threshold values
upper_threshold = best_f1_threshold
# upper_threshold = best_precision_threshold

# Classifying OOD examples (sandals)
ood_classifications = softmax_thresh_classifications(ood_probs, upper_threshold)

# Classifying ID examples (T-shirts and pants)
id_classifications = softmax_thresh_classifications(id_probs, upper_threshold)

# Combine OOD and ID classifications and true labels
all_predictions = np.concatenate([ood_classifications, id_classifications])
all_true_labels = np.concatenate([-1 * np.ones(ood_classifications.shape), train_labels])

# Confusion matrix
cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

# Plotting the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix for OOD and ID Classification')
plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix2.png', dpi=300, bbox_inches='tight')
plt.show()
Optimized threshold confusion matrix

Example 2: Energy-Based OOD Detection

TODO: Provide background and intuition surrounding the energy-based measure. Some notes below (a minimal sketch of the energy score follows these notes):

Liu et al., Energy-based Out-of-distribution Detection, NeurIPS 2020; https://arxiv.org/pdf/2010.03759

  • E(x, y) = energy value
  • If x and y are “compatible”, the energy is lower
  • Energy can be turned into a probability through the Gibbs distribution
    • looks at the integral over all possible y’s
  • With energy scores, ID and OOD distributions become much more separable
  • Another “output-based” method, like softmax
  • I believe this measure is explicitly designed to work with neural nets, but may (?) work with other models
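
To make these notes concrete, below is a minimal sketch of the energy score as defined in Liu et al. (2020), computed from a classifier's raw outputs: E(x) = -T * log(sum_i exp(f_i(x)/T)). The `logits` array here is a hypothetical stand-in for the outputs of a trained model; later in this episode we use pytorch-ood's EnergyBased detector rather than computing it by hand.

PYTHON

import numpy as np

# Minimal sketch of the energy score from Liu et al. (2020):
#   E(x) = -T * log(sum_i exp(f_i(x) / T))
# `logits` is a hypothetical (n_samples, n_classes) array of raw outputs from a trained classifier.
def energy_score(logits, T=1.0):
    scaled = np.asarray(logits, dtype=float) / T
    m = scaled.max(axis=1, keepdims=True)                          # for numerical stability
    logsumexp = m.squeeze(1) + np.log(np.exp(scaled - m).sum(axis=1))
    return -T * logsumexp

# Confident (large-magnitude) logits yield lower energy than flat logits,
# so a higher energy score suggests the input is more likely OOD.
print(energy_score([[8.0, -2.0], [0.3, 0.2]]))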

Introducing PyTorch OOD

The PyTorch-OOD library provides methods for OOD detection and other closely related fields, such as anomaly detection and novelty detection. Visit the docs to learn more: pytorch-ood.readthedocs.io/en/latest/info.html

This library provides a streamlined way to calculate both energy and softmax scores from a trained model.

Setup example

In this example, we will train a CNN model on the Fashion MNIST dataset. We will then repeat a similar process as we did with softmax scores to evaluate how well the energy metric can separate ID and OOD data.

We’ll start fresh by loading our data again. This time, let’s treat all remaining classes in the Fashion MNIST dataset as OOD. This should yield a more robust model that is more reliable when presented with all kinds of data.

PYTHON

ood_data, train_data, train_labels, test_data, test_labels = prep_ID_OOD_datasests([0,1], list(range(2,10))) # use remaining 8 classes in dataset as OOD
fig = plot_data_sample(train_data, ood_data)
fig.savefig('../images/OOD-detection_image-data-preview.png', dpi=300, bbox_inches='tight')
plt.show()

Visualizing OOD and ID data

UMAP (or similar)

Recall in our previous example, we used PCA to visualize the ID and OOD data distributions. This was appropriate given that we were evaluating OOD/ID data in the context of a linear model. However, when working with nonlinear models such as CNNs, it makes more sense to investigate how the data is represented in a nonlinear space. Nonlinear embedding methods, such as Uniform Manifold Approximation and Projection (UMAP), are more suitable in such scenarios.

UMAP is a non-linear dimensionality reduction technique that preserves both the global structure and the local neighborhood relationships in the data. UMAP is often better at maintaining the continuity of data points that lie on non-linear manifolds. It can reveal nonlinear patterns and structures that PCA might miss, making it a valuable tool for analyzing ID and OOD distributions.

PYTHON

plot_umap = True  # set to False to skip the UMAP visualization (it can be slow)
if plot_umap:
    import umap
    # Flatten images for PCA and logistic regression
    train_data_flat = train_data.reshape((train_data.shape[0], -1))
    test_data_flat = test_data.reshape((test_data.shape[0], -1))
    ood_data_flat = ood_data.reshape((ood_data.shape[0], -1))
    
    print(f'train_data_flat.shape={train_data_flat.shape}')
    print(f'test_data_flat.shape={test_data_flat.shape}')
    print(f'ood_data_flat.shape={ood_data_flat.shape}')
    
    # Perform UMAP to visualize the data
    umap_reducer = umap.UMAP(n_components=2, random_state=42)
    combined_data = np.vstack([train_data_flat, ood_data_flat])
    combined_labels = np.hstack([train_labels, np.full(ood_data_flat.shape[0], 2)])  # Use 2 for OOD class
    
    umap_results = umap_reducer.fit_transform(combined_data)
    
    # Split the results back into in-distribution and OOD data
    umap_in_dist = umap_results[:len(train_data_flat)]
    umap_ood = umap_results[len(train_data_flat):]

PYTHON

if plot_umap:
    umap_alpha = .02

    # Plotting UMAP components
    plt.figure(figsize=(10, 6))
    
    # Plot in-distribution data
    scatter1 = plt.scatter(umap_in_dist[train_labels == 0, 0], umap_in_dist[train_labels == 0, 1], c='blue', label='T-shirts (ID)', alpha=umap_alpha)
    scatter2 = plt.scatter(umap_in_dist[train_labels == 1, 0], umap_in_dist[train_labels == 1, 1], c='red', label='Trousers (ID)', alpha=umap_alpha)
    
    # Plot OOD data
    scatter3 = plt.scatter(umap_ood[:, 0], umap_ood[:, 1], c='green', label='OOD', edgecolor='k', alpha=umap_alpha)
    
    # Create a single legend for all classes
    plt.legend(handles=[scatter1, scatter2, scatter3], loc="upper right")
    plt.xlabel('First UMAP Component')
    plt.ylabel('Second UMAP Component')
    plt.title('UMAP of In-Distribution and OOD Data')
    plt.show()

Train CNN

PYTHON

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F

# Convert to PyTorch tensors and normalize
train_data_tensor = torch.tensor(train_data, dtype=torch.float32).unsqueeze(1) / 255.0
test_data_tensor = torch.tensor(test_data, dtype=torch.float32).unsqueeze(1) / 255.0
ood_data_tensor = torch.tensor(ood_data, dtype=torch.float32).unsqueeze(1) / 255.0

train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)

train_dataset = torch.utils.data.TensorDataset(train_data_tensor, train_labels_tensor)
test_dataset = torch.utils.data.TensorDataset(test_data_tensor, test_labels_tensor)
ood_dataset = torch.utils.data.TensorDataset(ood_data_tensor, torch.zeros(ood_data_tensor.shape[0], dtype=torch.long))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=64, shuffle=False)

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64*5*5, 128)  # Updated this line
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64*5*5)  # Updated this line
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

train_model(model, train_loader, criterion, optimizer)

The warning message indicates that UMAP has overridden the n_jobs parameter to 1 due to the random_state being set. This behavior ensures reproducibility by using a single job. If you want to avoid the warning and still use parallelism, you can remove the random_state parameter. However, removing random_state will mean that the results might not be reproducible.

PYTHON

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Function to plot confusion matrix
def plot_confusion_matrix(labels, predictions, title):
    cm = confusion_matrix(labels, predictions, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["T-shirt/top", "Trouser"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title(title)
    plt.show()

# Function to evaluate model on a dataset
def evaluate_model(model, dataloader, device):
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(preds.cpu().numpy())
    return np.array(all_labels), np.array(all_predictions)

# Evaluate on train data
train_labels, train_predictions = evaluate_model(model, train_loader, device)
plot_confusion_matrix(train_labels, train_predictions, "Confusion Matrix for Train Data")

# Evaluate on test data
test_labels, test_predictions = evaluate_model(model, test_loader, device)
plot_confusion_matrix(test_labels, test_predictions, "Confusion Matrix for Test Data")

# Evaluate on OOD data
ood_labels, ood_predictions = evaluate_model(model, ood_loader, device)
plot_confusion_matrix(ood_labels, ood_predictions, "Confusion Matrix for OOD Data")

PYTHON

from scipy.stats import gaussian_kde
from pytorch_ood.detector import EnergyBased
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# Compute softmax scores
def get_softmax_scores(model, dataloader):
    model.eval()
    softmax_scores = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            softmax = torch.nn.functional.softmax(outputs, dim=1)
            softmax_scores.extend(softmax.cpu().numpy())
    return np.array(softmax_scores)

id_softmax_scores = get_softmax_scores(model, test_loader)
ood_softmax_scores = get_softmax_scores(model, ood_loader)

# Initialize the energy-based OOD detector
energy_detector = EnergyBased(model, t=1.0)

# Compute energy scores
def get_energy_scores(detector, dataloader):
    scores = []
    detector.model.eval()
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            score = detector.predict(inputs)
            scores.extend(score.cpu().numpy())
    return np.array(scores)

id_energy_scores = get_energy_scores(energy_detector, test_loader)
ood_energy_scores = get_energy_scores(energy_detector, ood_loader)

import matplotlib.pyplot as plt


# Plot PSDs

# Function to plot PSD
def plot_psd(id_scores, ood_scores, method_name):
    plt.figure(figsize=(12, 6))
    alpha = 0.3

    # Plot PSD for ID scores
    id_density = gaussian_kde(id_scores)
    x_id = np.linspace(id_scores.min(), id_scores.max(), 1000)
    plt.plot(x_id, id_density(x_id), label=f'ID ({method_name})', color='blue', alpha=alpha)

    # Plot PSD for OOD scores
    ood_density = gaussian_kde(ood_scores)
    x_ood = np.linspace(ood_scores.min(), ood_scores.max(), 1000)
    plt.plot(x_ood, ood_density(x_ood), label=f'OOD ({method_name})', color='red', alpha=alpha)

    plt.xlabel('Score')
    plt.ylabel('Density')
    plt.title(f'Probability Density Distributions for {method_name} Scores')
    plt.legend()
    plt.show()

# Plot PSD for softmax scores
plot_psd(id_softmax_scores[:, 1], ood_softmax_scores[:, 1], 'Softmax')

# Plot PSD for energy scores
plot_psd(id_energy_scores, ood_energy_scores, 'Energy')

PYTHON

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay

# Define thresholds to evaluate
thresholds = np.linspace(id_energy_scores.min(), id_energy_scores.max(), 50)

# Store evaluation metrics for each threshold
accuracies = []
precisions = []
recalls = []
f1_scores = []

# True labels for OOD data (since they are not part of the original labels)
ood_true_labels = np.full(len(ood_energy_scores), -1)

# We need the test_labels to be aligned with the ID data
id_true_labels = test_labels[:len(id_energy_scores)]

for threshold in thresholds:
    # Classify OOD examples based on energy scores
    ood_classifications = np.where(ood_energy_scores >= threshold, -1,  # classified as OOD
                                   np.where(ood_energy_scores < threshold, 0, -1))  # classified as ID

    # Classify ID examples based on energy scores
    id_classifications = np.where(id_energy_scores >= threshold, -1,  # classified as OOD
                                  np.where(id_energy_scores < threshold, id_true_labels, -1))  # classified as ID

    # Combine OOD and ID classifications and true labels
    all_predictions = np.concatenate([ood_classifications, id_classifications])
    all_true_labels = np.concatenate([ood_true_labels, id_true_labels])

    # Evaluate metrics
    precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[0, 1], average='macro')#, zero_division=0)
    accuracy = accuracy_score(all_true_labels, all_predictions)

    accuracies.append(accuracy)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Find the best thresholds for each metric
best_f1_index = np.argmax(f1_scores)
best_f1_threshold = thresholds[best_f1_index]

best_precision_index = np.argmax(precisions)
best_precision_threshold = thresholds[best_precision_index]

best_recall_index = np.argmax(recalls)
best_recall_threshold = thresholds[best_recall_index]

print(f"Best F1 threshold: {best_f1_threshold}, F1 Score: {f1_scores[best_f1_index]}")
print(f"Best Precision threshold: {best_precision_threshold}, Precision: {precisions[best_precision_index]}")
print(f"Best Recall threshold: {best_recall_threshold}, Recall: {recalls[best_recall_index]}")

# Plot metrics as functions of the threshold
plt.figure(figsize=(12, 8))
plt.plot(thresholds, precisions, label='Precision', color='g')
plt.plot(thresholds, recalls, label='Recall', color='b')
plt.plot(thresholds, f1_scores, label='F1 Score', color='r')

# Add best threshold indicators
plt.axvline(x=best_f1_threshold, color='r', linestyle='--', label=f'Best F1 Threshold: {best_f1_threshold:.2f}')
plt.axvline(x=best_precision_threshold, color='g', linestyle='--', label=f'Best Precision Threshold: {best_precision_threshold:.2f}')
plt.axvline(x=best_recall_threshold, color='b', linestyle='--', label=f'Best Recall Threshold: {best_recall_threshold:.2f}')

plt.xlabel('Threshold')
plt.ylabel('Metric Value')
plt.title('Evaluation Metrics as Functions of Threshold (Energy-Based OOD Detection)')
plt.legend()
plt.show()

PYTHON

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay

def evaluate_ood_detection(id_scores, ood_scores, id_true_labels, id_predictions, ood_predictions, score_type='energy'):
    """
    Evaluate OOD detection based on either energy scores or softmax scores.

    Parameters:
    - id_scores: np.array, scores for in-distribution (ID) data
    - ood_scores: np.array, scores for out-of-distribution (OOD) data
    - id_true_labels: np.array, true labels for ID data
    - id_predictions: np.array, predicted labels for ID data
    - ood_predictions: np.array, predicted labels for OOD data
    - score_type: str, type of score used ('energy' or 'softmax')

    Returns:
    - Best thresholds for F1, Precision, and Recall
    - Plots of Precision, Recall, and F1 Score as functions of the threshold
    """
    # Define thresholds to evaluate
    if score_type == 'softmax':
        thresholds = np.linspace(0.5, 1.0, 200)
    else:
        thresholds = np.linspace(id_scores.min(), id_scores.max(), 50)

    # Store evaluation metrics for each threshold
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []

    # True labels for OOD data (since they are not part of the original labels)
    ood_true_labels = np.full(len(ood_scores), -1)

    for threshold in thresholds:
        # Classify OOD examples based on scores
        if score_type == 'energy':
            ood_classifications = np.where(ood_scores >= threshold, -1, ood_predictions)
            id_classifications = np.where(id_scores >= threshold, -1, id_predictions)
        elif score_type == 'softmax':
            ood_classifications = np.where(ood_scores <= threshold, -1, ood_predictions)
            id_classifications = np.where(id_scores <= threshold, -1, id_predictions)
        else:
            raise ValueError("Invalid score_type. Use 'energy' or 'softmax'.")

        # Combine OOD and ID classifications and true labels
        all_predictions = np.concatenate([ood_classifications, id_classifications])
        all_true_labels = np.concatenate([ood_true_labels, id_true_labels])

        # Evaluate metrics
        precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[-1, 0], average='macro', zero_division=0)
        accuracy = accuracy_score(all_true_labels, all_predictions)

        accuracies.append(accuracy)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    # Find the best thresholds for each metric
    best_f1_index = np.argmax(f1_scores)
    best_f1_threshold = thresholds[best_f1_index]

    best_precision_index = np.argmax(precisions)
    best_precision_threshold = thresholds[best_precision_index]

    best_recall_index = np.argmax(recalls)
    best_recall_threshold = thresholds[best_recall_index]

    print(f"Best F1 threshold: {best_f1_threshold}, F1 Score: {f1_scores[best_f1_index]}")
    print(f"Best Precision threshold: {best_precision_threshold}, Precision: {precisions[best_precision_index]}")
    print(f"Best Recall threshold: {best_recall_threshold}, Recall: {recalls[best_recall_index]}")

    # Plot metrics as functions of the threshold
    plt.figure(figsize=(12, 8))
    plt.plot(thresholds, precisions, label='Precision', color='g')
    plt.plot(thresholds, recalls, label='Recall', color='b')
    plt.plot(thresholds, f1_scores, label='F1 Score', color='r')

    # Add best threshold indicators
    plt.axvline(x=best_f1_threshold, color='r', linestyle='--', label=f'Best F1 Threshold: {best_f1_threshold:.2f}')
    plt.axvline(x=best_precision_threshold, color='g', linestyle='--', label=f'Best Precision Threshold: {best_precision_threshold:.2f}')
    plt.axvline(x=best_recall_threshold, color='b', linestyle='--', label=f'Best Recall Threshold: {best_recall_threshold:.2f}')

    plt.xlabel('Threshold')
    plt.ylabel('Metric Value')
    plt.title(f'Evaluation Metrics as Functions of Threshold ({score_type.capitalize()}-Based OOD Detection)')
    plt.legend()
    plt.show()

    # Plot a confusion matrix at the best F1 threshold
    upper_threshold = best_f1_threshold

    # Re-classify OOD and ID examples at this threshold (same logic as in the loop above)
    if score_type == 'energy':
        ood_classifications = np.where(ood_scores >= upper_threshold, -1, ood_predictions)
        id_classifications = np.where(id_scores >= upper_threshold, -1, id_predictions)
    else:
        ood_classifications = np.where(ood_scores <= upper_threshold, -1, ood_predictions)
        id_classifications = np.where(id_scores <= upper_threshold, -1, id_predictions)

    # Combine OOD and ID classifications and true labels
    all_predictions = np.concatenate([ood_classifications, id_classifications])
    all_true_labels = np.concatenate([ood_true_labels, id_true_labels])

    # Confusion matrix
    cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

    # Plotting the confusion matrix
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title(f'Confusion Matrix for OOD and ID Classification ({score_type.capitalize()}-Based)')
    plt.show()


    return best_f1_threshold, best_precision_threshold, best_recall_threshold

# Example usage
# Assuming id_energy_scores, ood_energy_scores, id_true_labels, test_predictions, and ood_predictions are already defined
best_f1_threshold, best_precision_threshold, best_recall_threshold = evaluate_ood_detection(id_energy_scores, ood_energy_scores, id_true_labels, test_predictions, ood_predictions, score_type='energy')
best_f1_threshold, best_precision_threshold, best_recall_threshold = evaluate_ood_detection(id_softmax_scores[:,0], ood_softmax_scores[:,0], id_true_labels, test_predictions, ood_predictions, score_type='softmax')

PYTHON

# Quick sanity check: one softmax score per OOD sample was passed above
ood_softmax_scores[:,0].shape

PYTHON

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Threshold value for the energy score
upper_threshold = best_f1_threshold  # Using the best F1 threshold from the previous calculation

# Classifying OOD examples based on energy scores
ood_classifications = np.where(ood_energy_scores >= upper_threshold, -1,  # classified as OOD
                               np.where(ood_energy_scores < upper_threshold, 0, -1))  # classified as ID

# Classifying ID examples based on energy scores
id_classifications = np.where(id_energy_scores >= upper_threshold, -1,  # classified as OOD
                              np.where(id_energy_scores < upper_threshold, id_true_labels, -1))  # classified as ID

# Combine OOD and ID classifications and true labels
all_predictions = np.concatenate([ood_classifications, id_classifications])
all_true_labels = np.concatenate([ood_true_labels, id_true_labels])

# Confusion matrix
cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

# Plotting the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix for OOD and ID Classification (Energy-Based)')
plt.show()

Limitations of our approach thus far

  • Focus on a single OOD class: More reliable/accurate thresholds can/should be obtained using a wider variety (more classes) and a larger sample of OOD data. This is part of the challenge of OOD detection: the space of possible OOD data is vast. Possible exercise: Redo the thresholding using all remaining classes in the dataset.

References and supplemental resources

Content from OOD detection: distance-based and contrastive learning


Last updated on 2024-08-14 | Edit this page

Overview

Questions

  • How do distance-based methods like Mahalanobis distance and KNN work for OOD detection?
  • What is contrastive learning and how does it improve feature representations?
  • How does contrastive learning enhance the effectiveness of distance-based OOD detection methods?

Objectives

  • Gain a thorough understanding of distance-based OOD detection methods, including Mahalanobis distance and KNN.
  • Learn the principles of contrastive learning and its role in improving feature representations.
  • Explore the synergy between contrastive learning and distance-based OOD detection methods to enhance detection performance.

Example 3: Distance-Based Methods

Lee et al., A simple unified framework for detecting out-of-distribution samples and adversarial attacks. NeurIPS 2018.

With softmax and energy-based methods, we focus on the model’s outputs to determine a threshold that separates ID and OOD data. With distance-based methods, we focus on the feature representations learned by the model.

In the case of neural networks, a common approach is to use the penultimate layer as a feature representation that defines an ID cluster for each class. You can then use the distance to the closest class centroid as a proxy OOD measure.

Mahalanobis distance (parametric)

Model the feature space as a mixture of multivariate Gaussian distributions, one for each class, and use the Mahalanobis distance to the closest class centroid as a proxy OOD measure. A minimal sketch is shown below.
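
The sketch below is a rough illustration of this idea (not the exact implementation from Lee et al.). It assumes hypothetical arrays `id_features` (penultimate-layer features of ID training data), `id_labels` (their class labels), and a single `query` feature vector, and estimates a shared ("tied") covariance matrix across classes.

PYTHON

import numpy as np

# Rough sketch of a Mahalanobis-distance OOD score (assumes hypothetical
# `id_features`, `id_labels`, and `query` arrays of penultimate-layer features).
def mahalanobis_ood_score(query, id_features, id_labels):
    classes = np.unique(id_labels)
    means = {c: id_features[id_labels == c].mean(axis=0) for c in classes}

    # Shared ("tied") covariance estimated from class-centered ID features
    centered = np.vstack([id_features[id_labels == c] - means[c] for c in classes])
    cov_inv = np.linalg.pinv(np.cov(centered, rowvar=False))  # pseudo-inverse for numerical safety

    # Squared Mahalanobis distance to each class mean; keep the smallest
    dists = [float((query - means[c]) @ cov_inv @ (query - means[c])) for c in classes]
    return min(dists)  # larger value -> farther from every class -> more likely OOD

# Usage sketch (assuming such arrays exist):
# is_ood = mahalanobis_ood_score(query_features[0], id_features, id_labels) > threshold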

Limitations of the parametric approach

  • Strong distributional assumption (features may not necessarily be Gaussian-distributed)
  • Suboptimal embedding

Nearest Neighbor Distance (non-parametric)

Sun et al., Out-of-distribution Detection with Deep Nearest Neighbors, ICML 2022

  • A sample is considered OOD if it has a large KNN distance w.r.t. the training data (and vice versa); see the sketch after this list.
  • No distributional assumptions about the underlying embedding space, giving stronger generality and flexibility than the Mahalanobis distance.
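Below is a similarly minimal sketch of KNN-based scoring using scikit-learn, again assuming feature arrays like those we extract later in this episode. Here k is a tunable hyperparameter, and the L2 normalization of features follows common practice from Sun et al. (2022).

PYTHON

import numpy as np
from sklearn.neighbors import NearestNeighbors

def knn_ood_scores(train_features, test_features, k=50):
    """Score each test sample by the distance to its k-th nearest training neighbor."""
    # L2-normalize features so distances are not dominated by feature magnitude
    train_norm = train_features / np.linalg.norm(train_features, axis=1, keepdims=True)
    test_norm = test_features / np.linalg.norm(test_features, axis=1, keepdims=True)
    nn = NearestNeighbors(n_neighbors=k).fit(train_norm)
    dists, _ = nn.kneighbors(test_norm)
    # Distance to the k-th neighbor; larger values suggest OOD
    return dists[:, -1]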

CIDER

This one might be out of scope…

Ming et al., How to Exploit Hyperspherical Embeddings for Out-of-Distribution Detection

Contrastive Learning

  • Explain the basic idea of contrastive learning: learning representations by contrasting positive and negative pairs.
  • Highlight the role of contrastive learning in learning discriminative features that can separate in-distribution (ID) from OOD data more effectively.
  • Illustrate how contrastive learning improves the feature space, making distance-based methods (like Mahalanobis and KNN) more effective.
  • Provide examples or case studies where contrastive learning has been applied to enhance OOD detection.

Example X: Comparing feature representations with and without contrastive learning

Returning to UMAP

Notice how in our UMAP visualization, we saw three distinct clusters representing each class. However, our model still confidently rated many sandals as T-shirts. The crux of this issue is that models do not know what they don't know. They simply draw classification boundaries between the classes available to them during training.

One way to get around this problem is to train models to learn discriminative features…

Contrastive learning

In this experiment, we use both a traditional neural network and a contrastive learning model to classify images from the Fashion MNIST dataset, focusing on T-shirts (class 0) and Trousers (class 1). Additionally, we evaluate the models on out-of-distribution (OOD) data, specifically Sandals (class 5). To visualize the models’ learned features, we extract features from specific layers of the neural networks and reduce their dimensionality using UMAP.

Overview of steps

1) Train model

  • With or without contrastive learning
  • Focusing on T-shirts (class 0) and Trousers (class 1)
  • Additionally, we evaluate the models on out-of-distribution (OOD) data, specifically Sandals (class 5)

2) Feature Extraction:

  • After training, we set the models to evaluation mode to prevent updates to the model parameters.
  • For each subset of the data (training, validation, and OOD), we pass the images through the entire network up to the first fully connected layer.
  • The output of this layer, which captures high-level features and abstractions, is then used as a 1D feature vector.
  • These feature vectors are detached from the computational graph and converted to NumPy arrays for further processing.

3) Dimensionality Reduction and Visualization:

  • We combine the feature vectors from the training, validation, and OOD data into a single dataset.
  • UMAP (Uniform Manifold Approximation and Projection) is used to reduce the dimensionality of the feature vectors from the high-dimensional space to 2D, making it possible to visualize the relationships between different data points.
  • The reduced features are then plotted, with different colors representing the training data (T-shirts and Trousers), validation data (T-shirts and Trousers), and OOD data (Sandals).

By visualizing the features generated from different subsets of the data, we can observe how well the models have learned to distinguish between in-distribution classes (T-shirts and Trousers) and handle OOD data (Sandals). This approach allows us to evaluate the robustness and generalization capabilities of the models in dealing with data that may not have been seen during training.

Standard neural network w/out contrastive learning

1) Train model

We’ll first train our vanilla CNN w/out contrastive learning.

  • Focusing on T-shirts (class 0) and Trousers (class 1)
  • Additionally, we evaluate the models on out-of-distribution (OOD) data, specifically Sandals (class 5)

PYTHON

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset

# Check if GPU is available and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

PYTHON

# Define a simple CNN model for classification
class ClassificationModel(nn.Module):
    def __init__(self):
        super(ClassificationModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 2)  # 2 classes for T-shirts and Trousers

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Load Fashion MNIST dataset and filter for T-shirts and Trousers
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_indices = np.where((train_dataset.targets == 0) | (train_dataset.targets == 1))[0]
val_indices = np.where((test_dataset.targets == 0) | (test_dataset.targets == 1))[0]
ood_indices = np.where(test_dataset.targets == 5)[0]

# Use a subset of the data for quicker training
train_subset = Subset(train_dataset, np.random.choice(train_indices, size=5000, replace=False))
val_subset = Subset(test_dataset, np.random.choice(val_indices, size=1000, replace=False))
ood_subset = Subset(test_dataset, np.random.choice(ood_indices, size=1000, replace=False))

train_loader = DataLoader(train_subset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=256, shuffle=False)
ood_loader = DataLoader(ood_subset, batch_size=256, shuffle=False)

# Initialize the model and move it to the device
classification_model = ClassificationModel().to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classification_model.parameters(), lr=0.001)

# Training loop for standard neural network
n_epochs = 10  # number of training epochs (assumed value; adjust as needed)
train_losses = []
val_losses = []

for epoch in range(n_epochs):
    total_train_loss = 0
    classification_model.train()
    for batch_images, batch_labels in train_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        outputs = classification_model(batch_images)
        loss = criterion(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    total_val_loss = 0
    classification_model.eval()
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
            outputs = classification_model(batch_images)
            loss = criterion(outputs, batch_labels)
            total_val_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    avg_val_loss = total_val_loss / len(val_loader)
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

PYTHON

# Plot training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_epochs + 1), train_losses, label='Train Loss')
plt.plot(range(1, n_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss - Classification Model')
plt.legend()
plt.show()

2) Extracting learned features

  • After training, we set the models to evaluation mode to prevent updates to the model parameters.
  • For each subset of the data (training, validation, and OOD), we pass the images through the entire network up to the first fully connected layer.
  • The output of this layer, which captures high-level features and abstractions, is then used as a 1D feature vector.
  • These feature vectors are detached from the computational graph and converted to NumPy arrays for further processing.

Why later layer features are better

In both the traditional neural network and the contrastive learning model, we will extract features from the first fully connected layer (fc1) before the final classification layer. Here’s why this layer is particularly suitable for feature extraction:

  • Hierarchical feature representation: In neural networks, the initial layers typically capture low-level features such as edges, textures, and simple shapes (e.g., with CNNs). As you move deeper into the network, the layers capture higher-level, more abstract features that are more relevant for the final classification task. These high-level features are combinations of the low-level features and are typically more discriminative.

  • Better separation of classes: Features from later layers have been transformed through several layers of non-linear activations and pooling operations, making them more suitable for distinguishing between classes. These features are usually more compact and have a better separation in the feature space, which helps in visualization and understanding the model’s decision-making process.

PYTHON

# Extract features using the trained classification model
classification_model.eval()
train_features = []
train_labels_list = []
for batch_images, batch_labels in train_loader:
    batch_images = batch_images.to(device)
    # conv1 -> ReLU -> flatten -> fc1, matching the model's forward pass up to fc1
    features = classification_model.fc1(classification_model.flatten(classification_model.relu(classification_model.conv1(batch_images))))
    train_features.append(features.detach().cpu().numpy())
    train_labels_list.append(batch_labels.numpy())

val_features = []
val_labels_list = []
for batch_images, batch_labels in val_loader:
    batch_images = batch_images.to(device)
    features = classification_model.fc1(classification_model.flatten(classification_model.relu(classification_model.conv1(batch_images))))
    val_features.append(features.detach().cpu().numpy())
    val_labels_list.append(batch_labels.numpy())

ood_features = []
ood_labels_list = []
for batch_images, batch_labels in ood_loader:
    batch_images = batch_images.to(device)
    features = classification_model.fc1(classification_model.flatten(classification_model.relu(classification_model.conv1(batch_images))))
    ood_features.append(features.detach().cpu().numpy())
    ood_labels_list.append(batch_labels.numpy())

3) Dimensionality Reduction and Visualization:

  • We combine the feature vectors from the training, validation, and OOD data into a single dataset.
  • UMAP (Uniform Manifold Approximation and Projection) is used to reduce the dimensionality of the feature vectors from the high-dimensional space to 2D, making it possible to visualize the relationships between different data points.
  • The reduced features are then plotted, with different colors representing the training data (T-shirts and Trousers), validation data (T-shirts and Trousers), and OOD data (Sandals).

PYTHON

import umap  # provided by the umap-learn package

train_features = np.concatenate(train_features)
train_labels = np.concatenate(train_labels_list)
val_features = np.concatenate(val_features)
val_labels = np.concatenate(val_labels_list)
ood_features = np.concatenate(ood_features)
ood_labels = np.concatenate(ood_labels_list)

# Perform UMAP to visualize the classification model features
combined_features = np.vstack([train_features, val_features, ood_features])
combined_labels = np.hstack([train_labels, val_labels, np.full(len(ood_labels), 2)])  # Use 2 for OOD class

umap_reducer = umap.UMAP(n_components=2, random_state=42)
umap_results = umap_reducer.fit_transform(combined_features)

# Split the results back into train, val, and OOD data
umap_train_features = umap_results[:len(train_features)]
umap_val_features = umap_results[len(train_features):len(train_features) + len(val_features)]
umap_ood_features = umap_results[len(train_features) + len(val_features):]

PYTHON

# Plotting UMAP components for classification model
alpha = .2
plt.figure(figsize=(10, 6))
# Plot train T-shirts
scatter1 = plt.scatter(umap_train_features[train_labels == 0, 0], umap_train_features[train_labels == 0, 1], c='blue', alpha=alpha, label='Train T-shirts (ID)')
# Plot train Trousers
scatter2 = plt.scatter(umap_train_features[train_labels == 1, 0], umap_train_features[train_labels == 1, 1], c='red', alpha=alpha, label='Train Trousers (ID)')
# Plot val T-shirts
scatter3 = plt.scatter(umap_val_features[val_labels == 0, 0], umap_val_features[val_labels == 0, 1], c='blue', alpha=alpha, marker='x', label='Val T-shirts (ID)')
# Plot val Trousers
scatter4 = plt.scatter(umap_val_features[val_labels == 1, 0], umap_val_features[val_labels == 1, 1], c='red', alpha=alpha, marker='x', label='Val Trousers (ID)')
# Plot OOD Sandals
scatter5 = plt.scatter(umap_ood_features[:, 0], umap_ood_features[:, 1], c='green', alpha=alpha, marker='o', label='OOD Sandals')
plt.legend(handles=[scatter1, scatter2, scatter3, scatter4, scatter5])
plt.xlabel('First UMAP Component')
plt.ylabel('Second UMAP Component')
plt.title('UMAP of Classification Model Features')
plt.show()

Neural network trained with contrastive learning

What is Contrastive Learning?

Contrastive learning is a technique where the model learns to distinguish between similar and dissimilar pairs of data. This can be achieved through different types of learning: supervised, unsupervised, and self-supervised.

  • Supervised Contrastive Learning: Uses labeled data to create pairs or groups of similar and dissimilar data points based on their labels.

  • Unsupervised Contrastive Learning: Does not use any labels. Instead, it relies on inherent patterns in the data to create pairs. For example, random pairs of data points might be assumed to be dissimilar, while augmented versions of the same data point might be assumed to be similar.

  • Self-Supervised Contrastive Learning: A form of unsupervised learning where the model generates its own supervisory signal from the data. This typically involves data augmentation techniques where positive pairs are created by augmenting the same image (e.g., cropping, rotating), and negative pairs are formed from different images.

In contrastive learning, the model learns to bring similar pairs closer in the embedding space while pushing dissimilar pairs further apart. This approach is particularly useful for tasks like image retrieval, clustering, and representation learning.
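For contrast with the supervised approach we use below, a self-supervised positive pair is typically created by augmenting the same image twice. Here is a minimal, illustrative sketch; the specific transforms and their parameters are assumptions, not part of our pipeline.

PYTHON

from torchvision import transforms

# Random augmentations; two augmented views of one image form a "similar" (positive) pair
augment = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
])

def make_positive_pair(image):
    # image: a Fashion MNIST tensor of shape (1, 28, 28)
    return augment(image), augment(image)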

Let’s expand on how we treat the T-shirt, Trouser, and Sandal classes in the context of our supervised contrastive learning framework.

Key Concepts in Our Code

Data Preparation

  • Dataset: We use the Fashion MNIST dataset, which contains images of various clothing items, each labeled with a specific class.
  • Class Filtering: For this exercise, we are focusing on three classes from the Fashion MNIST dataset:
    • T-shirts (class label 0)
    • Trousers (class label 1)
    • Sandals (class label 5)
  • In-Distribution (ID) Data: We treat T-shirts and Trousers as our primary classes for training. These are considered “in-distribution” data.
  • Out-of-Distribution (OOD) Data: Sandals are treated as a different class for testing the robustness of our learned embeddings, making them “out-of-distribution” data.

Pairs Creation

For each image in our training set:

  • Positive Pair: We find another image of the same class (either T-shirt or Trouser). These pairs are labeled as similar.
  • Negative Pair: We randomly choose an image from a different class (a T-shirt paired with a Trouser, or vice versa). These pairs are labeled as dissimilar.

By creating these pairs, the model learns to produce embeddings where similar images (same class) are close together, and dissimilar images (different classes) are farther apart.

Model Architecture

The model is a simple Convolutional Neural Network (CNN) designed to output embeddings. It consists of:

  • A convolutional layer to extract features from the images.
  • Fully connected layers to map these features to a 50-dimensional embedding space.

Training Process

  • Forward Pass: The model processes pairs of images and outputs their embeddings.
  • Contrastive Loss: We use a contrastive loss function to train the model. This loss encourages embeddings of similar pairs to be close and embeddings of dissimilar pairs to be far apart. Specifically, we:
    • Normalize the embeddings.
    • Calculate similarity scores.
    • Compute the contrastive loss, which penalizes similar pairs if they are not close enough and dissimilar pairs if they are too close.

Differences from Standard Neural Network Training

  • Data Pairing: In contrastive learning, we create pairs of data points. Standard neural network training typically involves individual data points with corresponding labels.
  • Loss Function: We use a contrastive loss function instead of the typical cross-entropy loss used in classification tasks. The contrastive loss is designed to optimize the relative distances between pairs of embeddings.
  • Supervised Learning: Our approach uses labeled data to form similar and dissimilar pairs, making it supervised contrastive learning. This contrasts with self-supervised or unsupervised methods where labels are not used.

Specific Type of Contrastive Learning

The specific contrastive learning technique we are using here is a form of supervised contrastive learning. This involves using labeled data to create similar and dissimilar pairs of images. The model is trained to output embeddings where a contrastive loss function is applied to these pairs. By doing so, the model learns to map images into an embedding space where similar images are close together, and dissimilar images are farther apart.

By training with this method, the model learns robust feature representations that are useful for various downstream tasks, even with limited labeled data. This is powerful because it allows leveraging labeled data to improve the model’s performance and generalizability.

Application of the Framework

  1. Training with In-Distribution Data:
    • T-shirts and Trousers: These classes are used to train the model. Positive and negative pairs are created within this subset to teach the model to distinguish between the two classes.
  2. Testing with Out-of-Distribution Data:
    • Sandals: This class is used to test the robustness of the embeddings learned by the model. By introducing a completely different class during testing, we can evaluate how well the model generalizes to new, unseen data.

This framework demonstrates how supervised contrastive learning can be effectively applied to learn discriminative embeddings that can generalize well to both in-distribution and out-of-distribution data.

PYTHON

import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import umap
import torch.nn as nn
import torch.optim as optim

class PairDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img1 = self.images[idx]
        label1 = self.labels[idx]
        idx2 = np.random.choice(np.where(self.labels == label1)[0])
        img2 = self.images[idx2]
        return img1, img2, label1

# Load Fashion MNIST dataset and filter for T-shirts and Trousers
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_indices = np.where((train_dataset.targets == 0) | (train_dataset.targets == 1))[0]
val_indices = np.where((test_dataset.targets == 0) | (test_dataset.targets == 1))[0]
ood_indices = np.where(test_dataset.targets == 5)[0]

# Use a subset of the data for quicker training
train_indices = np.random.choice(train_indices, size=5000, replace=False)
val_indices = np.random.choice(val_indices, size=1000, replace=False)
ood_indices = np.random.choice(ood_indices, size=1000, replace=False)

# Build image/label arrays for the pair datasets from the sampled indices
train_images = np.array([train_dataset[i][0].numpy() for i in train_indices])
train_labels = np.array([train_dataset[i][1] for i in train_indices])
val_images = np.array([test_dataset[i][0].numpy() for i in val_indices])
val_labels = np.array([test_dataset[i][1] for i in val_indices])
ood_images = np.array([test_dataset[i][0].numpy() for i in ood_indices])
ood_labels = np.array([test_dataset[i][1] for i in ood_indices])

train_loader = DataLoader(PairDataset(train_images, train_labels), batch_size=256, shuffle=True)
val_loader = DataLoader(PairDataset(val_images, val_labels), batch_size=256, shuffle=False)
ood_loader = DataLoader(PairDataset(ood_images, ood_labels), batch_size=256, shuffle=False)

# Inspect the data loaders
for batch_images1, batch_images2, batch_labels in train_loader:
    print(f"train_loader batch_images1 shape: {batch_images1.shape}")
    print(f"train_loader batch_images2 shape: {batch_images2.shape}")
    print(f"train_loader batch_labels shape: {batch_labels.shape}")
    break

for batch_images1, batch_images2, batch_labels in val_loader:
    print(f"val_loader batch_images1 shape: {batch_images1.shape}")
    print(f"val_loader batch_images2 shape: {batch_images2.shape}")
    print(f"val_loader batch_labels shape: {batch_labels.shape}")
    break

for batch_images1, batch_images2, batch_labels in ood_loader:
    print(f"ood_loader batch_images1 shape: {batch_images1.shape}")
    print(f"ood_loader batch_images2 shape: {batch_images2.shape}")
    print(f"ood_loader batch_labels shape: {batch_labels.shape}")
    break

PYTHON

# Define a simple CNN model for contrastive learning
class ContrastiveModel(nn.Module):
    def __init__(self):
        super(ContrastiveModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 50)  # Embedding size

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Define contrastive loss function
def contrastive_loss(z_i, z_j, temperature=0.5):
    z_i = nn.functional.normalize(z_i, dim=1)
    z_j = nn.functional.normalize(z_j, dim=1)
    batch_size = z_i.size(0)
    z = torch.cat([z_i, z_j], dim=0)

    sim = torch.mm(z, z.t()) / temperature
    sim_i_j = torch.diag(sim, batch_size)
    sim_j_i = torch.diag(sim, -batch_size)

    positives = torch.cat([sim_i_j, sim_j_i], dim=0)
    negatives_mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    negatives = sim[negatives_mask].view(2 * batch_size, -1)

    loss = -torch.mean(positives) + torch.mean(negatives)
    return loss

# Training loop for contrastive learning
def train_contrastive_model(model, train_loader, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for img1, img2, _ in train_loader:
            img1, img2 = img1.to(device), img2.to(device)

            optimizer.zero_grad()

            z_i = model(img1)
            z_j = model(img2)

            loss = contrastive_loss(z_i, z_j)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Instantiate the model, optimizer, and start training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
contrastive_model = ContrastiveModel().to(device)
optimizer = optim.Adam(contrastive_model.parameters(), lr=0.001)

train_contrastive_model(contrastive_model, train_loader, optimizer, num_epochs=n_epochs)

2) Extracting learned features

  • After training, we set the models to evaluation mode to prevent updates to the model parameters.
  • For each subset of the data (training, validation, and OOD), we pass the images through the entire network up to the first fully connected layer.
  • The output of this layer, which captures high-level features and abstractions, is then used as a 1D feature vector.
  • These feature vectors are detached from the computational graph and converted to NumPy arrays for further processing.

PYTHON

# Extract features using the trained contrastive model
contrastive_model.eval()
train_features = []
train_labels_list = []
for img1, _, label1 in train_loader:
    img1 = img1.to(device)
    # conv1 -> ReLU -> flatten -> fc1, matching the model's forward pass up to fc1
    features = contrastive_model.fc1(contrastive_model.flatten(contrastive_model.relu(contrastive_model.conv1(img1))))
    train_features.append(features.detach().cpu().numpy())
    train_labels_list.append(label1.numpy())

val_features = []
val_labels_list = []
for img1, _, label1 in val_loader:
    img1 = img1.to(device)
    features = contrastive_model.fc1(contrastive_model.flatten(contrastive_model.relu(contrastive_model.conv1(img1))))
    val_features.append(features.detach().cpu().numpy())
    val_labels_list.append(label1.numpy())

ood_features = []
ood_labels_list = []
for img1, _, label1 in ood_loader:
    img1 = img1.to(device)
    features = contrastive_model.fc1(contrastive_model.flatten(contrastive_model.relu(contrastive_model.conv1(img1))))
    ood_features.append(features.detach().cpu().numpy())
    ood_labels_list.append(label1.numpy())

train_features = np.concatenate(train_features)
train_labels = np.concatenate(train_labels_list)
val_features = np.concatenate(val_features)
val_labels = np.concatenate(val_labels_list)
ood_features = np.concatenate(ood_features)
ood_labels = np.concatenate(ood_labels_list)

# Diagnostic print statements
print(f"train_features shape: {train_features.shape}")
print(f"train_labels shape: {train_labels.shape}")
print(f"val_features shape: {val_features.shape}")
print(f"val_labels shape: {val_labels.shape}")
print(f"ood_features shape: {ood_features.shape}")
print(f"ood_labels shape: {ood_labels.shape}")

3) Dimensionality Reduction and Visualization:

  • We combine the feature vectors from the training, validation, and OOD data into a single dataset.
  • UMAP (Uniform Manifold Approximation and Projection) is used to reduce the dimensionality of the feature vectors from the high-dimensional space to 2D, making it possible to visualize the relationships between different data points.
  • The reduced features are then plotted, with different colors representing the training data (T-shirts and Trousers), validation data (T-shirts and Trousers), and OOD data (Sandals).

PYTHON

# Ensure the labels array for OOD matches the feature array length
combined_features = np.vstack([train_features, val_features, ood_features])
combined_labels = np.hstack([train_labels, val_labels, np.full(len(ood_features), 2)])  # Use 2 for OOD class

umap_reducer = umap.UMAP(n_components=2, random_state=42)
umap_results = umap_reducer.fit_transform(combined_features)

# Split the results back into train, val, and OOD data
umap_train_features = umap_results[:len(train_features)]
umap_val_features = umap_results[len(train_features):len(train_features) + len(val_features)]
umap_ood_features = umap_results[len(train_features) + len(val_features):]

# Plotting UMAP components for contrastive learning model
plt.figure(figsize=(10, 6))
# Plot train T-shirts
scatter1 = plt.scatter(umap_train_features[train_labels == 0, 0], umap_train_features[train_labels == 0, 1], c='blue', alpha=0.5, label='Train T-shirts (ID)')
# Plot train Trousers
scatter2 = plt.scatter(umap_train_features[train_labels == 1, 0], umap_train_features[train_labels == 1, 1], c='red', alpha=0.5, label='Train Trousers (ID)')
# Plot val T-shirts
scatter3 = plt.scatter(umap_val_features[val_labels == 0, 0], umap_val_features[val_labels == 0, 1], c='blue', alpha=0.5, marker='x', label='Val T-shirts (ID)')
# Plot val Trousers
scatter4 = plt.scatter(umap_val_features[val_labels == 1, 0], umap_val_features[val_labels == 1, 1], c='red', alpha=0.5, marker='x', label='Val Trousers (ID)')
# Plot OOD Sandals
scatter5 = plt.scatter(umap_ood_features[:, 0], umap_ood_features[:, 1], c='green', alpha=0.5, marker='o', label='OOD Sandals')
plt.legend(handles=[scatter1, scatter2, scatter3, scatter4, scatter5])
plt.xlabel('First UMAP Component')
plt.ylabel('Second UMAP Component')
plt.title('UMAP of Contrastive Model Features')
plt.show()

Limitations of Threshold-Based OOD Detection Methods

Threshold-based out-of-distribution (OOD) detection methods are widely used due to their simplicity and intuitive nature. However, they come with several significant limitations that need to be considered:

  1. Dependence on OOD Data Choice:
    • Variety and Representation: The effectiveness of threshold-based methods heavily relies on the variety and representativeness of the OOD data used during threshold selection. If the chosen OOD samples do not adequately cover the possible range of OOD scenarios, the threshold may not generalize well to unseen OOD data.
    • Threshold Determination: To determine a robust threshold, it is essential to include a diverse set of OOD samples. This helps in setting a threshold that can effectively distinguish between in-distribution and out-of-distribution data across various scenarios. Without a comprehensive OOD dataset, the threshold might either be too conservative, causing many ID samples to be misclassified as OOD, or too lenient, failing to detect OOD samples accurately.
  2. Impact of High Thresholds:
    • False OOD Classification: High thresholds can lead to a significant number of ID samples being incorrectly classified as OOD. This false OOD classification results in the loss of potentially valuable data, reducing the efficiency and performance of the model.
    • Data Efficiency: In applications where retaining as much ID data as possible is crucial, high thresholds can be particularly detrimental. It’s important to strike a balance between detecting OOD samples and retaining ID samples to ensure the model’s overall performance and data efficiency.
  3. Sensitivity to Model Confidence:
    • Model Calibration: Threshold-based methods rely on the model’s confidence scores, which can be misleading if the model is poorly calibrated. Overconfident predictions for ID samples or underconfident predictions for OOD samples can result in suboptimal threshold settings.
    • Confidence Variability: The variability in confidence scores across different models and architectures can make it challenging to set a universal threshold. Each model might require different threshold settings, complicating the deployment and maintenance of threshold-based OOD detection systems.
  4. Lack of Discriminative Features:
    • Boundary-Based Detection: Threshold-based methods focus on class boundaries rather than learning discriminative features that can effectively separate ID and OOD samples. This approach can be less robust, particularly in complex or high-dimensional data spaces where class boundaries might be less clear.
    • Feature Learning: By relying solely on confidence scores, these methods miss the opportunity to learn and leverage features that are inherently more discriminative. This limitation highlights the need for advanced techniques like contrastive learning, which focuses on learning features that distinguish between ID and OOD samples more effectively.

Conclusion

While threshold-based OOD detection methods offer a straightforward approach, their limitations underscore the importance of considering additional OOD samples for robust threshold determination and the potential pitfalls of high thresholds. Transitioning to methods that learn discriminative features rather than relying solely on class boundaries can address these limitations, paving the way for more effective OOD detection. This sets the stage for discussing contrastive learning, which provides a powerful framework for learning such discriminative features.

Content from OOD detection: training-time regularization


Last updated on 2024-08-14 | Edit this page

Overview

Questions

  • What are the key considerations when designing algorithms for OOD detection?
  • How can OOD detection be incorporated into the loss functions of models?
  • What are the challenges and best practices for training models with OOD detection capabilities?

Objectives

  • Understand the critical design considerations for creating effective OOD detection algorithms.
  • Learn how to integrate OOD detection into the loss functions of machine learning models.
  • Identify the challenges in training models with OOD detection and explore best practices to overcome these challenges.

Training-time regularization for OOD detection

Content from Documenting and releasing a model


Last updated on 2024-07-16 | Edit this page

Overview

Questions

  • Why is model sharing important in the context of reproducibility and responsible use?
  • What are the challenges, risks, and ethical considerations related to sharing models?
  • How can model-sharing best practices be applied using tools like model cards and the Hugging Face platform?
  • What is distribution shift and what are its implications in machine learning models?

Objectives

  • Understand the importance of model sharing and best practices to ensure reproducibility and responsible use of models.
  • Understand the challenges, risks, and ethical concerns associated with model sharing.
  • Apply model-sharing best practices through using model cards and the Hugging Face platform.

Key Points

  • Model cards are the standard technique for communicating information about how machine learning systems were trained and how they should and should not be used.
  • Models can be shared and reused via the Hugging Face platform.

Why should we share trained models?

Discuss in small groups and report out: Why do you believe it is or isn’t important to share ML models? How has model-sharing contributed to your experiences or projects?

  • Accelerating research: Sharing models allows researchers and practitioners to build upon existing work, accelerating the pace of innovation in the field.
  • Knowledge exchange: Model sharing promotes knowledge exchange and collaboration within the machine learning community, fostering a culture of open science.
  • Reproducibility: Sharing models, along with associated code and data, enhances reproducibility, enabling others to validate and verify the results reported in research papers.
  • Benchmarking: Shared models serve as benchmarks for comparing new models and algorithms, facilitating the evaluation and improvement of state-of-the-art techniques.
  • Education / Accessibility to state-of-the-art architectures: Shared models provide valuable resources for educational purposes, allowing students and learners to explore and experiment with advanced machine learning techniques.
  • Repurpose (transfer learning and finetuning): Some models (e.g., foundation models) can be repurposed for a wide variety of tasks. This is especially useful when working with limited data (data scarcity).
  • Resource efficiency: Instead of training a model from the ground up, practitioners can use existing models as a starting point, saving time, computational resources, and energy.

Challenges and risks of model sharing

Discuss in small groups and report out: What are some potential challenges, risks, or ethical concerns associated with model sharing and reproducing ML workflows?

  • Privacy concerns: Sharing models that were trained on sensitive or private data raises privacy concerns. The potential disclosure of personal information through the model poses a risk to individuals and can lead to unintended consequences.
  • Informed consent: If models involve user data, ensuring informed consent is crucial. Sharing models trained on user-generated content without clear consent may violate privacy norms and regulations.
  • Data bias and fairness: Models trained on biased datasets may perpetuate or exacerbate existing biases. Reproducing workflows without addressing bias in the data may result in unfair outcomes, particularly in applications like hiring or criminal justice.
  • Intellectual property: Models may be developed within organizations with proprietary data and methodologies. Sharing such models without proper consent or authorization may lead to intellectual property disputes and legal consequences.
  • Model robustness and generalization: Reproduced models may not generalize well to new datasets or real-world scenarios. Failure to account for the limitations of the original model can result in reduced performance and reliability in diverse settings.
  • Lack of reproducibility: Incomplete documentation, missing details, or changes in dependencies over time can hinder the reproducibility of ML workflows. This lack of reproducibility can impede scientific progress and validation of research findings.
  • Unintended use and misuse: Shared models may be used in unintended ways, leading to ethical concerns. Developers should consider the potential consequences of misuse, particularly in applications with societal impact, such as healthcare or law enforcement.
  • Responsible AI considerations: Ethical considerations, such as fairness, accountability, and transparency, should be addressed during model sharing. Failing to consider these aspects can result in models that inadvertently discriminate or lack interpretability. Models used for decision-making, especially in critical areas like healthcare or finance, should be ethically deployed. Transparent documentation and disclosure of how decisions are made are essential for responsible AI adoption.

Saving model locally


Let’s review the simplest method for sharing a model first — saving the model locally. When working with PyTorch, it’s important to know how to save and load models efficiently. This process ensures that you can pause your work, share your models, or deploy them for inference without having to retrain them from scratch each time.

Define model

As an example, we’ll configure a simple perceptron (single hidden layer) in PyTorch. We’ll define a bare bones class for this just so we can initialize the model.

PYTHON

from typing import Dict, Any
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, config: Dict[str, int]):
        super().__init__()
        # Parameter is a trainable tensor initialized with random values
        self.param = nn.Parameter(torch.rand(config["num_channels"], config["hidden_size"]))
        # Linear layer (fully connected layer) for the output
        self.linear = nn.Linear(config["hidden_size"], config["num_classes"])
        # Store the configuration
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Forward pass: Add the input to the param tensor, then pass through the linear layer
        return self.linear(x + self.param)

Initialize model by calling the class with configuration settings.

PYTHON

# Create model instance with specific configuration
config = {"num_channels": 3, "hidden_size": 32, "num_classes": 10}
model = MyModel(config=config)

We can then write a function to save the model. We'll need both the model weights and the model's configuration (hyperparameter settings). We'll save the configuration as JSON, since a key/value format is convenient here.

PYTHON

import json

# Function to save model and config locally
def save_model(model: nn.Module, model_path: str, config_path: str) -> None:
    # Save model state dict (weights and biases) as a .pth file
    torch.save(model.state_dict(), model_path)
    # Save config
    with open(config_path, 'w') as f:
        json.dump(model.config, f)

PYTHON

# Save the model and config locally
save_model(model, "my_awesome_model.pth", "my_awesome_model_config.json")

To load the model back in, we can write another function

PYTHON

# Function to load model and config locally
def load_model(model_class: Any, model_path: str, config_path: str) -> nn.Module:
    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)
    # Create model instance with config
    model = model_class(config=config)
    # Load model state dict
    model.load_state_dict(torch.load(model_path))
    return model

PYTHON

# Load the model and config locally
loaded_model = load_model(MyModel, "my_awesome_model.pth", "my_awesome_model_config.json")

# Verify the loaded model
print(loaded_model)

Saving a model to Hugging Face


To share your model with a wider audience, we recommend uploading your model to Hugging Face. Hugging Face is a very popular machine learning (ML) platform and community that helps users build, deploy, share, and train machine learning models. It has quickly become the go-to option for sharing models with the public.

Create a Hugging Face account and access Token

If you haven’t completed these steps from the setup, make sure to do this now.

Create account: To create an account on Hugging Face, visit: huggingface.co/join. Enter an email address and password, and follow the instructions provided via Hugging Face (you may need to verify your email address) to complete the process.

Setup access token: Once you have your account created, you’ll need to generate an access token so that you can upload/share models to your Hugging Face account during the workshop. To generate a token, visit the Access Tokens setting page after logging in.

Login to Hugging Face account

To log in, you will need to retrieve your access token from the Access Tokens setting page.

PYTHON

!huggingface-cli login

You might get a message saying you cannot authenticate through git-credential as no helper is defined on your machine. TODO: What does this warning mean?
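Alternatively, if you are working in a Jupyter notebook, you can log in directly from Python using the huggingface_hub library. A minimal sketch is shown below; it prompts for the same access token generated above.

PYTHON

from huggingface_hub import notebook_login

# Opens a prompt where you can paste your Hugging Face access token
notebook_login()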

Once logged in, we will need to edit our model class definition to include Hugging Face's push_to_hub functionality. To enable push_to_hub, you'll need to include the PyTorchModelHubMixin "mixin class" provided by the huggingface_hub library. A mixin class is a type of class used in object-oriented programming to "mix in" additional properties and methods into a class. The PyTorchModelHubMixin class adds methods to your PyTorch model to enable easy saving and loading from the Hugging Face Model Hub.

Here’s how you can adjust the code to incorporate both saving/loading locally and pushing the model to the Hugging Face Hub.

PYTHON

from huggingface_hub import PyTorchModelHubMixin # NEW

class MyModel(nn.Module, PyTorchModelHubMixin): # PyTorchModelHubMixin is new
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        # Initialize layers and parameters
        self.param = nn.Parameter(torch.rand(config["num_channels"], config["hidden_size"]))
        self.linear = nn.Linear(config["hidden_size"], config["num_classes"])
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x + self.param)

PYTHON

# Create model instance with specific configuration
config = {"num_channels": 3, "hidden_size": 32, "num_classes": 10}
model = MyModel(config=config)
print(model)

PYTHON

# push to the hub
model.push_to_hub("my-awesome-model", config=config)

Verifying: To check your work, head back over to Hugging Face and click your profile icon in the top-right of the website. Click "Profile" from there to view all of your uploaded models. Alternatively, you can search for your username (or model name) in the Model Hub.

Loading the model from Hugging Face

PYTHON

# reload
model = MyModel.from_pretrained("your-username/my-awesome-model")

Uploading transformer models to Hugging Face


Key Differences

  • Saving and Loading the Tokenizer: Transformer models require a tokenizer that needs to be saved and loaded with the model. This is not necessary for custom PyTorch models, which typically do not require a separate tokenizer.
  • Using Pre-trained Classes: Transformer models use classes like AutoModelForSequenceClassification and AutoTokenizer from the transformers library, which are pre-built and designed for specific tasks (e.g., sequence classification).
  • Methods for Saving and Loading: The transformers library provides save_pretrained and from_pretrained methods for both models and tokenizers, which handle the serialization and deserialization processes seamlessly.

PYTHON

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load a pre-trained model and tokenizer
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Save the model and tokenizer locally
model.save_pretrained("my_transformer_model")
tokenizer.save_pretrained("my_transformer_model")

# Load the model and tokenizer from the saved directory
loaded_model = AutoModelForSequenceClassification.from_pretrained("my_transformer_model")
loaded_tokenizer = AutoTokenizer.from_pretrained("my_transformer_model")

# Verify the loaded model and tokenizer
print(loaded_model)
print(loaded_tokenizer)

PYTHON

# Push the model and tokenizer to Hugging Face Hub
model.push_to_hub("my-awesome-transformer-model")
tokenizer.push_to_hub("my-awesome-transformer-model")

# Load the model and tokenizer from the Hugging Face Hub
hub_model = AutoModelForSequenceClassification.from_pretrained("user-name/my-awesome-transformer-model")
hub_tokenizer = AutoTokenizer.from_pretrained("user-name/my-awesome-transformer-model")

# Verify the model and tokenizer loaded from the hub
print(hub_model)
print(hub_tokenizer)

What pieces must be well-documented to ensure reproducible and responsible model sharing?

Discuss in small groups and report out: What information must be well-documented to ensure reproducible and responsible model sharing?

  • Environment setup
  • Training data
    • How the data was collected
    • Who owns the data: data license and usage terms
    • Basic descriptive statistics: number of samples, features, classes, etc.
    • Note any class imbalance or general bias issues
    • Description of data distribution to help prevent out-of-distribution failures.
  • Preprocessing steps.
    • Data splitting
    • Standardization method
    • Feature selection
    • Outlier detection and other filters
  • Model architecture, hyperparameters, and training procedure (e.g., dropout or early stopping)
  • Model weights
  • Evaluation metrics. Results and performance. The more tasks/datasets you can evaluate on, the better.
  • Ethical considerations: Include investigations of bias/fairness when applicable (i.e., if your model involves human data or affects decision-making involving humans)
  • Contact info
  • Acknowledgments
  • Examples and demos (highly recommended)
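Much of the information above is conventionally collected in a model card (a README that accompanies the shared weights). As a rough, illustrative sketch, the huggingface_hub library can create and push a card programmatically; the card text below is a placeholder, and the repo name assumes the my-awesome-model example from earlier.

PYTHON

from huggingface_hub import ModelCard

# Minimal, illustrative card text; a real card should cover the checklist above
card_text = """---
license: mit
tags:
  - demo
---

# my-awesome-model

A toy model used to demonstrate model sharing in this workshop.

## Training data
Demo only; the model was not trained on real data.

## Intended use and limitations
For educational purposes; not suitable for production use.
"""

card = ModelCard(card_text)
card.push_to_hub("your-username/my-awesome-model")  # replace with your username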

Document your model

TODO

TODO