Content from Overview


Last updated on 2024-10-17

Estimated time: 31 minutes

Overview

Questions

  • What do we mean by “Trustworthy AI”?
  • How is this workshop structured, and what content does it cover?

Objectives

  • Define trustworthy AI and its various components.
  • Be prepared to dive into the rest of the workshop.

What is trustworthy AI?


Discussion

Take a moment to brainstorm what keywords/concepts come to mind when we mention “Trustworthy AI”. Share your thoughts with the class.

Artificial intelligence (AI) and machine learning (ML) are being used widely to improve upon human capabilities (either in speed/convenience/cost or accuracy) in a variety of domains: medicine, social media, news, marketing, policing, and more. It is important that the decisions made by AI/ML models uphold values that we, as a society, care about.

Trustworthy AI is a large and growing sub-field of AI that aims to ensure that AI models are trained and deployed in ways that are ethical and responsible.

The AI Bill of Rights


In October 2022, the Biden administration released a Blueprint for an AI Bill of Rights, a non-binding document that outlines how automated systems and AI should behave in order to protect Americans’ rights.

The blueprint is centered around five principles:

  • Safe and Effective Systems – AI systems should work as expected, and should not cause harm
  • Algorithmic Discrimination Protections – AI systems should not discriminate or produce inequitable outcomes
  • Data Privacy – data collection should be limited to what is necessary for the system functionality, and you should have control over how and if your data is used
  • Notice and Explanation – it should be transparent when an AI system is being used, and there should be an explanation of how particular decisions are reached
  • Human Alternatives, Consideration, and Fallback – you should be able to opt out of engaging with AI systems, and a human should be available to remedy any issues

This workshop


This workshop centers around four principles that are important to trustworthy AI: scientific validity, fairness, transparency, and accountability. We summarize each principle here.

Scientific validity

In order to be trustworthy, a model and its predictions need to be founded on good science. A model is not going to perform well if it is not trained on the correct data, if it fits the underlying data poorly, or if it cannot recognize its own limitations. Scientific validity is closely linked to the AI Bill of Rights principle of “safe and effective systems”.

In this workshop, we cover the following topics relating to scientific validity:

  • Defining the problem (Preparing to Train a Model episode)
  • Training and evaluating a model, especially selecting an accuracy metric, avoiding over/underfitting, and preventing data leakage (Model Evaluation and Fairness episode)
  • Estimating model uncertainty (Estimating Model Uncertainty episode)
  • Out-of-distribution detection (OOD Detection episodes)

Fairness

As stated in the AI Bill of Rights, AI systems should not be discriminatory or produce inequitable outcomes. In the Model Evaluation and Fairness episode we discuss various definitions of fairness in the context of AI, and overview how model developers try to make their models more fair.

Transparency

Transparency – i.e., insight into how a model makes its decisions – is important for trustworthy AI, as we want models that make the right decisions for the right reasons. Transparency can be achieved via explanations or by using inherently interpretable models. We discuss transparency in the following episodes:

  • Interpretability vs Explainability
  • Explainability Methods Overview
  • Explainability Methods: Deep Dive, Linear Probe, and GradCAM episodes

Accountability

Accountability is important for trustworthy AI because, inevitably, models will make mistakes or cause harm. Accountability is multi-faceted and largely non-technical; that does not make it unimportant, but it does fall partially outside the scope of this technical workshop.

We discuss two facets of accountability, model documentation and model sharing, in the Documenting and Releasing a Model episode.

For those who are interested, we recommend these papers to learn more about different aspects of AI accountability:

  1. Accountability of AI Under the Law: The Role of Explanation by Finale Doshi-Velez and colleagues. This paper discusses how explanations can be used in a legal context to determine accountability for harms caused by AI.
  2. Closing the AI accountability gap: defining an end-to-end framework for internal algorithmic auditing by Deborah Raji and colleagues proposes a framework for auditing algorithms. A key contribution of this paper is defining an auditing procedure over the whole model development and implementation pipeline, rather than narrowly focusing on the modeling stages.
  3. AI auditing: The Broken Bus on the Road to AI Accountability by Abeba Birhane and colleagues challenges previous work on AI accountability, arguing that most existing AI auditing systems are not effective. They propose necessary traits for effective AI audits, based on a review of existing practices.

Topics we do not cover

Trustworthy AI is a large, and growing, area of study. As of September 24, 2024, there are about 18,000 articles on Google Scholar that mention Trustworthy AI and were published in the first 9 months of 2024.

There are different Trustworthy AI methods for different types of models – e.g., decision trees or linear models that are commonly used with tabular data, neural networks that are used with image data, or large multi-modal foundation models. In this workshop, we focus primarily on neural networks for the specific techniques we show in the technical implementations. That being said, much of the conceptual content is relevant to any model type.

Many of the topics we do not cover are sub-topics of the workshop’s broad categories – e.g., fairness, explainability, or OOD detection – that are important for specific use cases but less relevant for a general audience. There are also a few major areas of research that we don’t have time to touch on. We summarize a few of them here:

Data Privacy

In the US’s Blueprint for an AI Bill of Rights, one principle is data privacy, meaning that people should be aware of how their data is being used, companies should not collect more data than they need, and people should be able to consent to and/or opt out of data collection and usage.

A lack of data privacy poses several risks: first, whenever data is collected, it can be subject to data breaches. This risk is unavoidable, but collecting only the data that is truly necessary mitigates it, as does implementing safeguards for how data is stored and accessed. Second, when data is used to train ML models, that data can sometimes be extracted by attackers. For instance, large language models like ChatGPT are known to reveal private data that was part of the training corpus when prompted in clever ways (see this blog post for more information).
Membership inference attacks, where an attacker determines whether a particular individual’s data was in the training corpus, are another vulnerability. These attacks may reveal things about a person directly (e.g., if the training dataset consisted of only people with a particular medical condition), or can be used to set up downstream attacks to gain more information.

There are several areas of active research relating to data privacy.

  • Differential privacy is a statistical technique that protects the privacy of individual data points. Models can be trained using differential privacy to provably prevent future attacks, but this currently comes at a high cost to accuracy.
  • Federated learning trains models using decentralized data from a variety of sources. Since the data is not shared centrally, there is less risk of data breaches or unauthorized data usage.

Generative AI risks

We touch on fairness issues with generative AI in the Model Evaluation and Fairness episode. But generative AI poses other risks, too, many of which are only starting to be researched and understood given how recently generative AI became widely available. We discuss one such risk, disinformation, briefly here:

  • Disinformation: A major risk of generative AI is the creation of misleading or fake and malicious content, often known as deep fakes. Deep fakes pose risks to individuals (e.g., creating content that harms an individual’s reputation) and society (e.g., fake news articles or pictures that look real).


Content from Preparing to train a model


Last updated on 2024-10-17

Estimated time: 0 minutes

Overview

Questions

  • For what prediction tasks is machine learning an appropriate tool?
  • How can inappropriate target variable choice lead to suboptimal outcomes in a machine learning pipeline?
  • What forms of “bias” can occur in machine learning, and where do these biases come from?

Objectives

  • Judge what tasks are appropriate for machine learning
  • Understand why the choice of prediction task / target variable is important.
  • Describe how bias can appear in training data and algorithms.

Choosing appropriate tasks


Machine learning is a rapidly advancing, powerful technology that is helping to drive innovation. Before embarking on a machine learning project, we need to consider the task carefully. Many machine learning efforts are not solving problems that need to be solved. Or, the problem may be valid, but the machine learning approach makes incorrect assumptions and fails to solve the problem effectively. Worse, many applications of machine learning are not for the public good.

We will start by considering the NIH Guiding Principles for Ethical Research, which provide a useful set of considerations for any project.

Challenge

Take a look at the NIH Guiding Principles for Ethical Research.

What are the main principles?

A summary of the principles is listed below:

  • Social and clinical value: Does the social or clinical value of developing and implementing the model outweigh the risk and burden of the people involved?
  • Scientific validity: Once created, will the model provide valid, meaningful outputs?
  • Fair subject selection: Are the people who contribute and benefit from the model selected fairly, and not through vulnerability, privilege, or other unrelated factors?
  • Favorable risk-benefit ratio: Do the potential benefits of developing and implementing the model outweigh the risks?
  • Independent review: Has the project been reviewed by someone independent of the project, and has an Institutional Review Board (IRB) been approached where appropriate?
  • Informed consent: Are participants whose data contributes to development and implementation of the model, as well as downstream recipients of the model, kept informed?
  • Respect for potential and enrolled subjects: Is the privacy of participants respected and are steps taken to continuously monitor the effect of the model on downstream participants?

AI tasks are often most controversial when they involve human subjects, and especially visual representations of people. We’ll discuss two case studies that use people’s faces as a prediction tool, and discuss whether these uses of AI are appropriate.

Case study 1: Diagnosing genetic disorders from facial images

In 2019, Nature Medicine published a paper that describes a model that can identify genetic disorders from a photograph of a patient’s face. The abstract of the paper is copied below:

Syndromic genetic conditions, in aggregate, affect 8% of the population. Many syndromes have recognizable facial features that are highly informative to clinical geneticists. Recent studies show that facial analysis technologies measured up to the capabilities of expert clinicians in syndrome identification. However, these technologies identified only a few disease phenotypes, limiting their role in clinical settings, where hundreds of diagnoses must be considered. Here we present a facial image analysis framework, DeepGestalt, using computer vision and deep-learning algorithms, that quantifies similarities to hundreds of syndromes.

DeepGestalt outperformed clinicians in three initial experiments, two with the goal of distinguishing subjects with a target syndrome from other syndromes, and one of separating different genetic sub-types in Noonan syndrome. On the final experiment reflecting a real clinical setting problem, DeepGestalt achieved 91% top-10 accuracy in identifying the correct syndrome on 502 different images. The model was trained on a dataset of over 17,000 images representing more than 200 syndromes, curated through a community-driven phenotyping platform. DeepGestalt potentially adds considerable value to phenotypic evaluations in clinical genetics, genetic testing, research and precision medicine.

  • What is the proposed value of the algorithm?
  • What are the potential risks?
  • Are you supportive of this kind of research?
  • What safeguards, if any, would you want to be used when developing and using this algorithm?
  • The algorithm could help doctors figure out what rare disease a patient has.
  • First, if the algorithm is used in the wrong hands, it could be used to discriminate against people with diseases. Second, if the algorithm is not accurate (false positive or false negative), trusting its results could lead to improper medical care.
  • Safeguards could include: requiring extensive testing to ensure that the algorithm maintains similar accuracy across racial and gender groups, making sure the algorithm is only accessible to medical professionals, and requiring follow-up testing to confirm the algorithm’s diagnosis.

Media reports about this paper were largely positive, e.g., reporting that clinicians are excited about the new technology.

Case study 2: Physiognomy

There is a long history of physiognomy, the “science” of trying to read someone’s character from their face. With the advent of machine learning, this discredited area of research has made a comeback. There have been numerous studies attempting to guess characteristics such as trustworthiness, criminality, and political and sexual orientation.

In 2018, for example, researchers suggested that neural networks could be used to detect sexual orientation from facial images. The abstract is copied below:

We show that faces contain much more information about sexual orientation than can be perceived and interpreted by the human brain. We used deep neural networks to extract features from 35,326 facial images. These features were entered into a logistic regression aimed at classifying sexual orientation. Given a single facial image, a classifier could correctly distinguish between gay and heterosexual men in 81% of cases, and in 74% of cases for women. Human judges achieved much lower accuracy: 61% for men and 54% for women. The accuracy of the algorithm increased to 91% and 83%, respectively, given five facial images per person.

Facial features employed by the classifier included both fixed (e.g., nose shape) and transient facial features (e.g., grooming style). Consistent with the prenatal hormone theory of sexual orientation, gay men and women tended to have gender-atypical facial morphology, expression, and grooming styles. Prediction models aimed at gender alone allowed for detecting gay males with 57% accuracy and gay females with 58% accuracy. Those findings advance our understanding of the origins of sexual orientation and the limits of human perception. Additionally, given that companies and governments are increasingly using computer vision algorithms to detect people’s intimate traits, our findings expose a threat to the privacy and safety of gay men and women.

Discussion

Discuss the following questions.

  • What is the proposed value of the algorithm?
  • What are the potential risks?
  • Are you supportive of this kind of research?
  • What distinguishes this use of AI from the use of AI described in Case Study 1?
  • The algorithm proposes detecting an individual’s sexual orientation from images of their face. It is unclear why this is something that needs to be algorithmically detected by any entity.
  • If the algorithm is used by anti-LGBTQ entities, it could be used to harass members of the LGBTQ community (as well as non-LGBTQ people who are flagged as LGBTQ by the algorithm). If the algorithm is accurate, it could be used to “out” individuals who are gay but do not want to publicly share that information.
  • The first case study aims to detect disease. The implication of this research – at least, as suggested by the linked article – is that it can help doctors with diagnosis and give individuals affected by the disease access to treatment. Conversely, if there is a medical reason for knowing someone’s sexual orientation, it is not necessary to use AI – the doctor can just ask the patient.

Media reports of this algorithm were largely negative, with a Scientific American article highlighting the connections to physiognomy and raising concern over government use of these algorithms:

This is precisely the kind of “scientific” claim that can motivate repressive governments to apply AI algorithms to images of their citizens. And what is it to stop them from “reading” intelligence, political orientation and criminal inclinations from these images?

Choosing the outcome variable


Sometimes, choosing the outcome variable is easy: for instance, when building a model to predict how warm it will be out tomorrow, the temperature can be the outcome variable because it’s measurable (i.e., you know what temperature it was yesterday and today) and your predictions won’t cause a feedback loop (e.g., given a set of past weather data, the weather next Monday won’t change based on what your model predicts tomorrow’s temperature to be).

By contrast, sometimes it’s not possible to measure the target prediction subject directly, and sometimes predictions can cause feedback loops.

Case Study: Proxy variables

Consider the scenario described in the challenge below.

Challenge

Suppose that you work for a hospital and are asked to build a model to predict which patients are high-risk and need extra care to prevent negative health outcomes.

Discuss the following with a partner or small group:

  1. What is the goal target variable?
  2. What are challenges in measuring the target variable in the training data (i.e., former patients)?
  3. Are there other variables that are easier to measure, but can approximate the target variable, that could serve as proxies?
  4. How do social inequities interplay with the value of the target variable versus the value of the proxies?

  1. The goal target variable is healthcare need.
  2. Patients’ healthcare needs are unknown unless they have seen a doctor. Patients who have less access to medical care may have under-documented needs. There might also be differences between doctors or hospital systems in how health conditions are documented.
  3. In the US, there are standard ways that healthcare billing information needs to be documented (this information may be more standardized than medical conditions). There may be more complete data from acute medical emergencies (i.e., emergency room visits) than there is for chronic conditions.
  4. In the US, healthcare access is often tied to employment, which means that wealthier people (who in the US, also tend to be white) have more access to healthcare.

The “challenge” scenario is not hypothetical: A well-known study by Obermeyer et al. analyzed an algorithm that hospitals used to assign patients risk scores for various conditions. The algorithm had access to various patient data, such as demographics (e.g., age and sex), the number of chronic conditions, insurance type, diagnoses, and medical costs. The algorithm did not have access to the patient’s race. The patient risk score determined the level of care the patient should receive, with higher-risk patients receiving additional care.

Ideally, the target variable would be health needs, but this can be challenging to measure: how do you compare the severity of two different conditions? Do you count chronic and acute conditions equally? In the system described by Obermeyer et al., the hospital decided to use health-care costs as a proxy for health needs, perhaps reasoning that this data is at least standardized across patients and doctors.

However, Obermeyer et al. reveal that the algorithm is biased against Black patients. That is, if there are two individuals – one white and one Black – with equal health, the algorithm tends to assign a higher risk score to the white patient, thus giving them access to higher care quality. The authors blame the choice of proxy variable for the racial disparities.

The authors go on to describe how, due to how health-care access is structured in the US, richer patients have more healthcare expenses, even if they are equally (un)healthy to a lower-income patient. The richer patients are also more likely to be white.

Consider the following:

  • How could the algorithm developers have caught this problem earlier?
  • Is this a technical mistake or a process-based mistake? Why?
  • The algorithm developers could have specifically tested for racial bias in their solution.
  • Discussion about the choice of target variable, including implementing models with different targets and comparing the results, could have exposed the bias.
  • Possibly a more diverse development team – e.g., including individuals who have firsthand experience with struggling to access healthcare – would have spotted the issue.
  • Note that it’s easier to see the mistake in hindsight than in the moment.

Case study: Feedback loop

Consider social media, like Instagram or TikTok’s “for you page” or Facebook or Twitter’s newsfeed. The algorithms that determine what to show are complex (and proprietary!), but a large part of the algorithms’ objective is engagement: the number of clicks, views, or re-posts. This focus on engagement can create an “echo chamber” where individual users solely see content that aligns with their political ideology, thereby maximizing the positive engagement with each post. But the impact of social media feedback loops spreads beyond politics: researchers have explored how similar feedback loops exist for mental health conditions such as eating disorders. If someone finds themselves in this area of social media, it is likely because they have, or have risk factors for, an eating disorder; pro-eating-disorder content may drive their engagement, but it is ultimately very harmful to their mental health.

Consider the following questions:

  • Why do social media companies optimize for engagement?
  • What would be an alternative optimization target? How would the outcomes differ, both for users and for the companies’ profits?
  • Social media companies optimize for engagement to maximize profits: if users keep using the service, they can sell more ads and bring in more revenue.
  • It’s hard to come up with an alternate optimization target that could be easily operationalized. Alternate goals could be social connection, learning, broadening one’s worldview, or even entertainment. But how would any of these goals be measured?

Recap - Choosing the right outcome variable: Sometimes, choosing the outcome variable is straightforward, like predicting tomorrow’s temperature. Other times, it gets tricky, especially when we can’t directly measure what we want to predict. It’s important to choose the right outcome variable because this decision plays a crucial role in ensuring our models are trustworthy, “fair” (more on this later), and unbiased. A poor choice can lead to biased results and unintended consequences, making it harder for our models to be effective and reliable.

Understanding bias


Now that we’ve covered the importance of the outcome variable, let’s talk about bias. Bias can show up in various ways during the modeling process, impacting our results and fairness. If we don’t consider bias from the beginning, we risk creating models that don’t work well for everyone or that reinforce existing inequalities.

So, what exactly do we mean by bias? The term is a little overloaded and can refer to different things depending on context. However, there are two general types/definitions of bias:

  • (Statistical) bias: This refers to the tendency of an algorithm to produce one solution over another, even when other options may be just as good or better. Statistical bias can arise from several sources (discussed below), including how data is collected and processed.
  • (Social) bias: outcomes are unfair to one or more social groups. Social bias can be the result of statistical bias (i.e., an algorithm giving preferential treatment to one social group over others), but can also occur outside of a machine learning context.

Sources of statistical bias

Algorithmic bias

Algorithmic bias is the tendency of an algorithm to favor one solution over another. Algorithmic bias is not always bad, and may sometimes be deliberately introduced by algorithm developers. For instance, linear regression with L0-regularization displays algorithmic bias towards sparse models (i.e., models where most weights are 0). This bias may be desirable in settings where human interpretability is important.
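
To make this concrete, here is a minimal sketch using scikit-learn’s L1-regularized Lasso as a stand-in for L0 regularization (which scikit-learn does not implement directly); both bias a linear model toward sparse solutions.

PYTHON

# L1 (Lasso) regularization as a stand-in for L0: both bias linear models
# toward sparse solutions, i.e., many weights equal to exactly 0.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression, Lasso

# synthetic data where only 5 of 50 features actually matter
X, y = make_regression(n_samples=200, n_features=50, n_informative=5,
                       noise=10, random_state=0)

ols = LinearRegression().fit(X, y)
lasso = Lasso(alpha=5.0).fit(X, y)

print("non-zero weights (OLS):  ", np.sum(ols.coef_ != 0))    # typically all 50
print("non-zero weights (Lasso):", np.sum(lasso.coef_ != 0))  # far fewer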

But algorithmic bias can also occur unintentionally: for instance, if there is data bias (described below), this may lead algorithm developers to select an algorithm that is ill-suited to underrepresented groups. Then, even if the data bias is rectified, sticking with the original algorithm choice may not fix biased outcomes.

Data bias

Data bias is when the available training data is not accurate or representative of the target population. Data bias is extremely common (it’s often hard to collect perfectly representative and perfectly accurate data), and can arise in multiple ways:

  • Measurement error - if a tool is not well calibrated, measurements taken by that tool won’t be accurate. Likewise, human biases can lead to measurement error, for instance, if people systematically over-report their height on dating apps, or if doctors do not believe patients’ self-reports of their pain levels.
  • Response bias - for instance, when conducting a survey about customer satisfaction, customers who had very positive or very negative experiences may be more likely to respond.
  • Representation bias - the data is not representative of the whole population. For instance, conducting clinical trials primarily on white men means that women and people of other races are not well represented in the data.

Through the rest of this lesson, if we use the term “bias” without any additional context, we will be referring to social bias that stems from statistical bias.

Case Study

With a partner or small group, choose one of the three case study options. Read or watch individually, then discuss as a group how bias manifested in the training data, and what strategies could correct for it.

After the discussion, share with the whole workshop what you discussed.

  1. Predictive policing
  2. Facial recognition (video, 5 min.)
  3. Amazon hiring tool
  1. Policing data does not provide a complete picture of crime: it only contains data about crimes that are reported. Some neighborhoods (in the US, usually poor neighborhoods with predominantly Black and Brown residents) are over-policed relative to other neighborhoods. As a result, data will suggest that the over-policed neighborhoods have more crime, and then will send more officers to patrol those areas, resulting in a feedback loop. Using techniques to clean and balance the data could help, but the article’s authors also point towards using non-technical solutions, such as algorithmic accountability and oversight frameworks.
  2. Commercially-available facial recognition systems have much higher accuracies on white men than on darker-skinned women. This discrepancy is attributed to imbalances in the training data. This problem could have been avoided if development teams were more diverse: e.g., if someone thought to evaluate the model on darker-skinned people during the development process. Then, collecting more data from underrepresented groups could improve accuracy on those individuals.
  3. Amazon tried to automate the resume-screening part of its hiring process, relying on data (e.g., resumes) from existing employees. However, the AI learned to discriminate against women because Amazon’s existing technical staff skewed heavily male. This could have been avoided in a couple ways: first, if Amazon did not have an existing gender skew, the data would have been cleaner. Second, given the gender skew in Amazon’s employees, model developers could have built in safeguards, e.g., mechanisms to satisfy some notion of fairness, such as deciding to interview an equal proportion of male and female job applicants.

Key Points

  • Some tasks are not appropriate for machine learning due to ethical concerns.
  • Machine learning tasks should have a valid prediction target that maps clearly to the real-world goal.
  • Training data can be biased due to societal inequities, errors in the data collection process, and lack of attention to careful sampling practices.
  • “Bias” also refers to statistical bias, and certain algorithms can be biased towards some solutions.

Content from Model evaluation and fairness


Last updated on 2024-10-15

Estimated time: 0 minutes

Overview

Questions

  • What metrics do we use to evaluate models?
  • What are some common pitfalls in model evaluation?
  • How do we define fairness and bias in machine learning outcomes?
  • What types of bias and unfairness can occur in generative AI?
  • What techniques exist to improve the fairness of ML models?

Objectives

  • Reason about model performance through standard evaluation metrics.
  • Recall how underfitting, overfitting, and data leakage impact model performance.
  • Understand and distinguish between various notions of fairness in machine learning.
  • Understand general approaches for improving the fairness of ML models.

Accuracy metrics


Stakeholders often want to know the accuracy of a machine learning model – what percent of predictions are correct? Accuracy can be decomposed into further metrics: e.g., in a binary prediction setting, recall (the fraction of positive samples that are classified correctly) and precision (the fraction of samples classified as positive that actually are positive) are commonly-used metrics.

Suppose we have a model that performs binary classification (+, -) on a test dataset of 1000 samples (let \(n\)=1000). A confusion matrix counts how many samples fall into each of four quadrants: true class + with a positive prediction (++), true class + with a negative prediction (+-), true class - with a positive prediction (-+), and true class - with a negative prediction (--).

              True +    True -
Predicted +   300       80
Predicted -   25        595

So, for instance, 80 samples have a true class of - but get predicted as members of +.

We can compute the following metrics:

  • Accuracy: What fraction of predictions are correct?
    • (300 + 595) / 1000 = 0.895
    • Accuracy is 89.5%
  • Precision: What fraction of predicted positives are true positives?
    • 300 / (300 + 80) = 0.789
    • Precision is 78.9%
  • Recall: What fraction of true positives are classified as positive?
    • 300 / (300 + 25) = 0.923
    • Recall is 92.3%
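
If you would like to check these numbers in code, here is a minimal sketch (assuming scikit-learn) that reconstructs the confusion matrix above and computes the same metrics:

PYTHON

# Rebuild the example confusion matrix as label arrays and compute the metrics.
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score

# 300 true + predicted +, 80 true - predicted +, 25 true + predicted -, 595 true - predicted -
y_true = np.array([1]*300 + [0]*80 + [1]*25 + [0]*595)
y_pred = np.array([1]*300 + [1]*80 + [0]*25 + [0]*595)

print("Accuracy: ", accuracy_score(y_true, y_pred))   # 0.895
print("Precision:", precision_score(y_true, y_pred))  # ~0.789
print("Recall:   ", recall_score(y_true, y_pred))     # ~0.923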

Callout

We’ve discussed binary classification but for other types of tasks there are different metrics. For example,

  • Multi-class problems often use Top-K accuracy, a metric of how often the true class appears among the model’s top K guesses.
  • Binary classifiers are also often evaluated with the Area Under the ROC Curve (AUC-ROC), a measure of how well the classifier performs across different decision thresholds. Regression tasks typically use error-based metrics such as mean squared error (MSE) or mean absolute error (MAE).

What accuracy metric to use?

Different accuracy metrics may be more relevant in different situations. Discuss with a partner or small groups whether precision, recall, or some combination of the two is most relevant in the following prediction tasks:

  1. Deciding what patients are high risk for a disease and who should get additional low-cost screening.
  2. Deciding what patients are high risk for a disease and should start taking medication to lower the disease risk. The medication is expensive and can have unpleasant side effects.
  1. It is best if all patients who need the screening get it, and there is little downside for doing screenings unnecessarily because the screening costs are low. Thus, a high recall score is optimal.

  2. Given the costs and side effects of the medicine, we do not want patients not at risk for the disease to take the medication. So, a high precision score is ideal.

Model evaluation pitfalls


Overfitting and underfitting

Overfitting is characterized by worse performance on the test set than on the train set. It can be mitigated by collecting more training data, switching to a simpler model architecture, or adding regularization.

Underfitting is characterized by poor performance on both the training and test datasets. It can be fixed by switching to a more complex model architecture or by improving feature quality.

[Figure: example of overfitting and underfitting]

If you need a refresher on how to detect overfitting and underfitting in your models, this article is a good resource.
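
As a quick hands-on check, here is a minimal sketch (assuming scikit-learn and synthetic data) of how comparing training and test accuracy can reveal overfitting:

PYTHON

# Compare train vs. test accuracy to spot overfitting.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# An unconstrained decision tree can memorize the training data (overfitting):
deep = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
print("deep tree    - train:", deep.score(X_train, y_train), "test:", deep.score(X_test, y_test))

# Regularizing the model (limiting tree depth) narrows the train/test gap:
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)
print("shallow tree - train:", shallow.score(X_train, y_train), "test:", shallow.score(X_test, y_test))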

Data Leakage

Data leakage occurs when information about the test data is available to the model during training, and it results in overly optimistic estimates of the model’s performance.

Recent work by Sayash Kapoor and Arvind Narayanan shows that data leakage is incredibly widespread in papers that use ML across several scientific fields. They define 8 common ways that data leakage occurs, including:

  1. No test set: there is no hold-out test-set, rather, the model is evaluated on a subset of the training data. This is the “obvious,” canonical example of data leakage.
  2. Preprocessing on whole dataset: when preprocessing occurs on the train + test sets, rather than just the train set, the model learns information about the test set that it should not have access to until later. For instance, missing feature imputation based on the full dataset will be different than missing feature imputation based only on the values in the train dataset.
  3. Illegitimate features: sometimes, there are features that are proxies for the outcome variable. For instance, if the goal is to predict whether a patient has hypertension, including whether they are on a common hypertension medication is data leakage since future, new patients would not already be on this medication.
  4. Temporal leakage: if the model predicts a future outcome, the training set should not contain information from the future. For instance, if the task is to predict whether a patient will develop a particular disease within 1 year, the dataset should not contain data points for the same patient from multiple years.
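
To illustrate the second pitfall (preprocessing on the whole dataset), here is a minimal sketch, assuming scikit-learn, that contrasts a leaky preprocessing step with the correct approach:

PYTHON

# Avoid leaking test-set information through preprocessing.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# LEAKY: the scaler is fit on ALL data, so test-set statistics influence training
leaky_scaler = StandardScaler().fit(X)
X_train_leaky = leaky_scaler.transform(X_train)

# CORRECT: fit preprocessing on the training split only (a pipeline handles this for you)
model = make_pipeline(StandardScaler(), LogisticRegression())
model.fit(X_train, y_train)              # the scaler sees only X_train
print("test accuracy:", model.score(X_test, y_test))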

Measuring fairness


What does it mean for a machine learning model to be fair or unbiased? There is no single definition of fairness, and we can talk about fairness at several levels (ranging from training data, to model internals, to how a model is deployed in practice). Similarly, bias is often used as a catch-all term for any behavior that we think is unfair. Even though there is no tidy definition of unfairness or bias, we can use aggregate model outputs to gain an overall understanding of how models behave with respect to different demographic groups – an approach called group fairness.

In general, if there are no differences between groups in the real world (e.g., if we lived in a utopia with no racial or gender gaps), achieving fairness is easy. But, in practice, in many social settings where prediction tools are used, there are differences between groups, e.g., due to historical and current discrimination.

For instance, in a loan prediction setting in the United States, the average white applicant may be better positioned to repay a loan than the average Black applicant due to differences in generational wealth, education opportunities, and other factors stemming from anti-Black racism. Suppose that a bank uses a machine learning model to decide who gets a loan. Suppose that 50% of white applicants are granted a loan, with a precision of 90% and a recall of 70% – in other words, 90% of white people granted loans end up repaying them, and 70% of all people who would have repaid the loan, if given the opportunity, get the loan. Consider the following scenarios:

  • (Demographic parity) We give loans to 50% of Black applicants in a way that maximizes overall accuracy
  • (Equalized odds) We give loans to X% of Black applicants, where X is chosen to maximize accuracy subject to keeping recall equal to 70%.
  • (Group level calibration) We give loans to X% of Black applicants, where X is chosen to maximize accuracy while keeping precision equal to 90%.

There are many notions of statistical group fairness, but most boil down to one of the three above options: demographic parity, equalized odds, and group-level calibration. All three are forms of distributional (or outcome) fairness. Another dimension, though, is procedural fairness: whether decisions are made in a just way, regardless of final outcomes. Procedural fairness contains many facets, but one way to operationalize it is to consider individual fairness (closely related to counterfactual fairness), which was suggested in 2012 by Dwork et al. as a way to ensure that “similar individuals [are treated] similarly”. For instance, if two individuals differ only on their race or gender, they should receive the same outcome from an algorithm that decides whether to approve a loan application.

In practice, it’s hard to use individual fairness because defining a complete set of rules about when two individuals are sufficiently “similar” is challenging.
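
To make the group-level notions above concrete, here is a minimal sketch (with made-up predictions and a hypothetical group attribute) of how selection rate, recall, and precision can be compared across groups:

PYTHON

# Compare selection rate (demographic parity), recall (equalized odds / TPR),
# and precision (calibration-style) across two demographic groups.
import numpy as np
from sklearn.metrics import recall_score, precision_score

# toy data: true outcomes, model predictions, and a binary group attribute
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 0])
y_pred = np.array([1, 0, 1, 0, 0, 1, 1, 0, 0, 0])
group  = np.array(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"])

for g in ["A", "B"]:
    mask = group == g
    print(f"group {g}:",
          "selection rate =", y_pred[mask].mean(),
          "recall =", recall_score(y_true[mask], y_pred[mask]),
          "precision =", precision_score(y_true[mask], y_pred[mask]))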

Matching fairness terminology with definitions

Match the following types of formal fairness with their definitions. (A) Individual fairness, (B) Equalized odds, (C) Demographic parity, and (D) Group-level calibration

  1. The model is equally accurate across all demographic groups.
  2. Different demographic groups have the same true positive rates and false positive rates.
  3. Similar people are treated similarly.
  4. People from different demographic groups receive each outcome at the same rate.

A - 3, B - 2, C - 4, D - 1

But some types of unfairness cannot be directly measured by group-level statistical data. In particular, generative AI opens up new opportunities for bias and unfairness. Bias can occur through representational harms (e.g., creating content that over-represents one population subgroup at the expense of another), or through stereotypes (e.g., creating content that reinforces real-world stereotypes about a group of people). We’ll discuss some specific examples of bias in generative models next.

Fairness in generative AI


Generative models learn from statistical patterns in real-world data. These statistical patterns reflect instances of bias in real-world data: what data is available on the internet, what stereotypes it reinforces, and what forms of representation are missing.

Natural language

One set of social stereotypes that large AI models can learn is gender-based: certain occupations are associated with men, and others with women. In the U.S., for example, doctors have historically and stereotypically been men.

In 2016, Caliskan et al. showed that machine translation systems exhibit gender bias, for instance, by reverting to stereotypical gendered pronouns in ambiguous translations, such as when translating from Turkish – a language without gendered pronouns – into English.

In response, Google tweaked their translator algorithms to identify and correct for gender stereotypes in Turkish and several other widely-spoken languages. So when we repeat a similar experiment today, we get the following output:

[Screenshot: Google Translate output. The English sentence "The doctor is on her lunch break" is translated to Turkish, and the Turkish output is translated back to English as either "The doctor is on his lunch break" or "The doctor is on her lunch break".]
Turkish Google Translate example (screenshot from 1/9/2024)

But for other, less widely-spoken languages, the original problem persists:

[Screenshot: Google Translate output. The English sentence "The doctor is on her lunch break" is translated to Norwegian, and the Norwegian output is translated back to English as "The doctor is on his lunch break".]
Norwegian Google Translate example (screenshot from 1/9/2024)

We’re not trying to slander Google Translate here – the translation, without additional context, is ambiguous. And even if they extended the existing solution to Norwegian and other languages, the underlying problem (stereotypes in the training data) still exists. And with generative AI such as ChatGPT, the problem can be even more pernicious.

Red-teaming large language models

In cybersecurity, “red-teaming” is when well-intentioned people think like a hacker in order to make a system safer. In the context of Large Language Models (LLMs), red-teaming is used to try to get LLMs to output offensive, inaccurate, or unsafe content, with the goal of understanding the limitations of the LLM.

Try out red-teaming with ChatGPT or another LLM. Specifically, can you construct a prompt that causes the LLM to output stereotypes? Here are some example prompts, but feel free to get creative!

  • “Tell me a story about a doctor” (or other profession with gender)

  • If you speak a language other than English, how are ambiguous gendered pronouns handled? For instance, try the prompt “Translate ‘The doctor is here’ to Spanish”. Is a masculine or feminine form used for the doctor in Spanish?

If you use LLMs in your research, consider whether any of these issues are likely to be present for your use cases. If you do not use LLMs in your research, consider how these biases can affect downstream uses of the LLM’s output.

Most publicly-available LLM providers set up guardrails to avoid propagating biases present in their training data. For instance, as of the time of this writing (January 2024), the first suggested prompt, “Tell me a story about a doctor,” consistently creates a story about a woman doctor. Similarly, substituting other professions that have strong associations with men (e.g., “electrical engineer,” “garbage collector,” and “US President”) for “doctor” yields stories with female or gender-neutral names and pronouns.

Discussing other fairness issues

If you use LLMs in your research, consider whether any of these issues are likely to be present for your use cases. Share your thoughts in small groups with other workshop participants.

Image generation

The same problems that affect language models also affect image generation. Consider, for instance, PULSE, an algorithm developed by Menon et al. that converts blurry images to higher-resolution ones. Soon after its release, biases were unearthed and shared via social media.

Challenge

Who is shown in this blurred picture?

[Image: a blurry, pixelated photo of Barack Obama]

While the picture is of Barack Obama, the upsampled image shows a white face.

[Image: the unblurred version of the pixelated picture. Instead of showing Obama, it shows a white man.]

You can try the model here.

Menon and colleagues subsequently updated their paper to discuss this issue of bias. They assert that the problems inherent in the PULSE model are largely a result of the underlying StyleGAN model, which they had used in their work.

Overall, it seems that sampling from StyleGAN yields white faces much more frequently than faces of people of color … This bias extends to any downstream application of StyleGAN, including the implementation of PULSE using StyleGAN.

Results indicate a racial bias among the generated pictures, with close to three-fourths (72.6%) of the pictures representing White people. Asian (13.8%) and Black (10.1%) are considerably less frequent, while Indians represent only a minor fraction of the pictures (3.4%).

These remarks get at a central issue: biases in any building block of a system (data, base models, etc.) get propagated forwards. In generative AI, such as text-to-image systems, this can result in representational harms, as documented by Bianchi et al. Fixing these issues of bias is still an active area of research. One important step is to be careful in data collection, and try to get a balanced dataset that does not contain harmful stereotypes. But large language models use massive training datasets, so it is not possible to manually verify data quality. Instead, researchers use heuristic approaches to improve data quality, and then rely on various techniques to improve models’ fairness, which we discuss next.

Improving fairness of models


Model developers frequently try to improve the fairness of their models by intervening at one of three stages: pre-processing, in-processing, or post-processing. We’ll cover techniques within each of these paradigms in turn.

We start, though, by discussing why removing the sensitive attribute(s) is not sufficient. Consider the task of deciding which loan applicants are funded. Suppose we are concerned with racial bias in the model outputs. If we remove race from the set of attributes available to the model, the model cannot make overtly racist decisions. However, it could instead make decisions based on zip code, which in the US is a very good proxy for race.

Can we simply remove all proxy variables? We could likely remove zip code, if we cannot identify a causal relationship between where someone lives and whether they will be able to repay a loan. But what about an attribute like educational achievement? Someone with a college degree (compared with someone with, say, less than a high school degree) has better employment opportunities and therefore might reasonably be expected to be more likely to be able to repay a loan. However, educational attainment is still a proxy for race in the United States due to historical (and ongoing) discrimination.

Pre-processing generally modifies the dataset used for learning. Techniques in this category include:

  • Oversampling/undersampling: instead of training a machine learning model on all of the data, undersample the majority class by removing some of the majority class samples from the dataset in order to have a more balanced dataset. Alternatively, oversample the minority class by duplicating samples belonging to this group (a rough sketch of this idea appears after this list).

  • Data augmentation: the number of samples from minority groups may be increased by generating synthetic data with a generative adversarial network (GAN). We won’t cover this method in this workshop (using a GAN can be more computationally expensive than other techniques). If you’re interested, you can learn more about this method from the paper Inclusive GAN: Improving Data and Minority Coverage in Generative Models.

  • Changing feature representations: various techniques have been proposed to increase fairness by removing unfairness from the data directly. To do so, the data is converted into an alternate representation so that differences between demographic groups are minimized, yet enough information is maintained in order to be able to learn a model that performs well. An advantage of this method is that it is model-agnostic; however, a challenge is that it reduces the interpretability of interpretable models and makes post-hoc explainability less meaningful for black-box models.
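
As a rough sketch of the oversampling idea (using numpy on a hypothetical tabular dataset), minority-group rows can be resampled with replacement until the groups are balanced:

PYTHON

# Naive oversampling: duplicate minority-group rows (sampled with replacement).
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))              # hypothetical features
group = np.array([0]*90 + [1]*10)          # group 1 is underrepresented

minority_idx = np.where(group == 1)[0]
extra = rng.choice(minority_idx, size=80, replace=True)   # rows to duplicate

X_balanced = np.vstack([X, X[extra]])
group_balanced = np.concatenate([group, group[extra]])
print("group counts after oversampling:", np.bincount(group_balanced))  # [90 90]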

Pros and cons of preprocessing options

Discuss what you think the pros and cons of the different pre-processing options are. What techniques might work better in different settings?

A downside of oversampling is that it may violate statistical assumptions about independence of samples. A downside of undersampling is that the total amount of data is reduced, potentially resulting in models that perform less well overall.

A downside of using GANs to generate additional data is that this process may be expensive and require higher levels of ML expertise.

A challenge with all techniques is that if there is not sufficient data from minority groups, it may be hard to achieve good performance on the groups without simply collecting more or higher-quality data.

In-processing modifies the learning algorithm. Some specific in-processing techniques include:

  • Reweighting samples: many machine learning models allow for reweighting individual samples, i.e., indicating that misclassifying certain, rarer, samples should be penalized more severely in the loss function. In the code example, we show how to reweight samples using AIF360’s Reweighing algorithm (a simplified, hand-weighted sketch also appears after this list).

  • Incorporating fairness into the loss function: reweighting explicitly instructs the loss function to penalize the misclassification of certain samples more harshly. However, another option is to add a term to the loss function corresponding to the fairness metric of interest.
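
Below is a minimal sketch of the reweighting idea using scikit-learn’s sample_weight argument on synthetic data; AIF360’s Reweighing algorithm computes appropriate weights for you, whereas here they are set by hand for illustration.

PYTHON

# Up-weight samples from an underrepresented group so both groups
# contribute equally to the training loss.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, random_state=0)
group = np.array([0]*180 + [1]*20)    # hypothetical group labels; group 1 is rare

# weight each sample inversely to its group's size
weights = np.where(group == 1, len(group) / (2 * 20), len(group) / (2 * 180))

model = LogisticRegression().fit(X, y, sample_weight=weights)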

Post-processing modifies an existing model to increase its fairness. Techniques in this category often compute a custom threshold for each demographic group in order to satisfy a specific notion of group fairness. For instance, if a machine learning model for a binary prediction task uses 0.5 as a cutoff (e.g., raw scores less than 0.5 get a prediction of 0 and others get a prediction of 1), fair post-processing techniques may select different thresholds, e.g., 0.4 or 0.6 for different demographic groups.
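
A minimal sketch of this group-specific thresholding idea (with made-up scores and a hypothetical group attribute) might look like:

PYTHON

# Apply group-specific decision thresholds to a model's raw scores.
import numpy as np

scores = np.array([0.45, 0.55, 0.62, 0.38, 0.51, 0.47])   # model output probabilities
group  = np.array(["A",  "A",  "A",  "B",  "B",  "B"])

# thresholds would be chosen (e.g., on validation data) to satisfy a fairness criterion
thresholds = {"A": 0.6, "B": 0.4}

y_pred = np.array([int(s >= thresholds[g]) for s, g in zip(scores, group)])
print(y_pred)   # [0 0 1 0 1 1]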

In the next episode, we explore two different bias mitigation strategies implemented in the AIF360 Fairness Toolkit.

Content from Model fairness: hands-on


Last updated on 2024-11-19

Estimated time: 0 minutes

Overview

Questions

  • How can we use AI Fairness 360 – a common toolkit – for measuring and improving model fairness?

Objectives

  • Describe and implement two different ways of modifying the machine learning modeling process to improve the fairness of a model.

In this episode, we will explore, hands-on, how to measure and improve fairness of ML models.

This notebook is adapted from AIF360’s Medical Expenditure Tutorial.

The tutorial uses data from the Medical Expenditure Panel Survey. We include a short description of the data below. For more details, especially on the preprocessing, please see the AIF360 tutorial.

To begin, we’ll import some generally-useful packages.

PYTHON

# import numpy
import numpy as np

# import Markdown for nice display
from IPython.display import Markdown, display

# import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

# import defaultdict (we'll use this instead of dict because it allows us to initialize a dictionary with a default value)
from collections import defaultdict

Scenario and data


The goal is to develop a healthcare utilization scoring model – i.e., to predict which patients will have the highest utilization of healthcare resources.

The original dataset contains information about various types of medical visits; the AIF360 preprocessing created a single output feature ‘UTILIZATION’ that combines utilization across all visit types. Then, this feature is binarized based on whether utilization is high, defined as >= 10 visits. Around 17% of the dataset has high utilization.

The sensitive feature (that we will base fairness scores on) is defined as race. Other predictors include demographics, health assessment data, past diagnoses, and physical/mental limitations.

The data is divided into years (we follow the lead of AIF360’s tutorial and use 2015), and further divided into Panels. We use Panel 19 (the first half of 2015).

Loading the data

Before starting, make sure you have downloaded the data as described in the setup instructions.

First, we need to import the dataset from the AI Fairness 360 library. Then, we can load in the data and create the train/validation/test splits. The rest of the code in the following blocks sets up information about the privileged and unprivileged groups. (Recall, we focus on race as the sensitive feature.)

PYTHON

from aif360.datasets import MEPSDataset19 # import the dataset

PYTHON

# assign train, validation, and test data. 
# Split the data into 50% train, 30% val, and 20% test
(dataset_orig_panel19_train,
 dataset_orig_panel19_val,
 dataset_orig_panel19_test) = MEPSDataset19().split([0.5, 0.8], shuffle=True) 

sens_ind = 0 # sensitive attribute index is 0
sens_attr = dataset_orig_panel19_train.protected_attribute_names[sens_ind] # sensitive attribute name

# find the attribute values that correspond to the privileged and unprivileged groups
unprivileged_groups = [{sens_attr: v} for v in
                       dataset_orig_panel19_train.unprivileged_protected_attributes[sens_ind]]
privileged_groups = [{sens_attr: v} for v in
                     dataset_orig_panel19_train.privileged_protected_attributes[sens_ind]]

Check object type.

PYTHON

type(dataset_orig_panel19_train)

Preview data.

PYTHON

dataset_orig_panel19_train.convert_to_dataframe()[0].head()

Show details about the data.

PYTHON

def describe(train:MEPSDataset19=None, val:MEPSDataset19=None, test:MEPSDataset19=None) -> None:
    '''
        Print information about the test dataset (and train and validation dataset, if 
        provided). Prints the dataset shape, favorable and unfavorable labels, 
        protected attribute names, and feature names.
    '''
    if train is not None:
        display(Markdown("#### Training Dataset shape"))
        print(train.features.shape) # print the shape of the training dataset - should be (7915, 138)
    if val is not None:
        display(Markdown("#### Validation Dataset shape"))
        print(val.features.shape)
    display(Markdown("#### Test Dataset shape"))
    print(test.features.shape)
    display(Markdown("#### Favorable and unfavorable labels"))
    print(test.favorable_label, test.unfavorable_label) # print favorable and unfavorable labels. Should be 1, 0
    display(Markdown("#### Protected attribute names"))
    print(test.protected_attribute_names) # print protected attribute name, "RACE"
    display(Markdown("#### Privileged and unprivileged protected attribute values"))
    print(test.privileged_protected_attributes, 
          test.unprivileged_protected_attributes) # print protected attribute values. Should be [1, 0]
    display(Markdown("#### Dataset feature names\n See [MEPS documentation](https://meps.ahrq.gov/data_stats/download_data/pufs/h181/h181doc.pdf) for details on the various features"))
    print(test.feature_names) # print feature names

describe(dataset_orig_panel19_train, dataset_orig_panel19_val, dataset_orig_panel19_test) # call our function "describe"

Next, we will look at whether the dataset contains bias; i.e., does the outcome ‘UTILIZATION’ take on a positive value more frequently for one racial group than another?

To check for biases, we will use the BinaryLabelDatasetMetric class from the AI Fairness 360 toolkit. This class creates an object that – given a dataset and user-defined sets of “privileged” and “unprivileged” groups – can compute various fairness scores. We will call the function MetricTextExplainer (also in AI Fairness 360) on the BinaryLabelDatasetMetric object to compute the disparate impact. Disparate impact is the rate at which the unprivileged group receives the favorable outcome divided by the rate at which the privileged group does, so a score of 1 indicates no bias and a score near 0 indicates extreme bias. In other words, we want a score that is close to 1, because this indicates that different demographic groups have similar outcomes under the model. A commonly used threshold for an “acceptable” disparate impact score is 0.8, because under U.S. law in various domains (e.g., employment and housing), a selection rate for a protected group that falls below 80% of the rate for the most-favored group is treated as evidence of adverse impact (the “four-fifths rule”).

PYTHON

# import MetricTextExplainer to be able to print descriptions of metrics
from aif360.explainers import MetricTextExplainer 

Some initial import error may occur since we’re using the CPU-only version of torch. If you run the import statement twice it should correct itself. We’ve coded this as a try/except statement below.

PYTHON

# import BinaryLabelDatasetMetric (class of metrics) 
try:
    from aif360.metrics import BinaryLabelDatasetMetric
except OSError as e:
    print(f"First import failed: {e}. Retrying...")
    from aif360.metrics import BinaryLabelDatasetMetric

print("Import successful!")

PYTHON

metric_orig_panel19_train = BinaryLabelDatasetMetric(
        dataset_orig_panel19_train, # train data
        unprivileged_groups=unprivileged_groups, # pass in names of unprivileged and privileged groups
        privileged_groups=privileged_groups)
explainer_orig_panel19_train = MetricTextExplainer(metric_orig_panel19_train) # create a MetricTextExplainer object

print(explainer_orig_panel19_train.disparate_impact()) # print disparate impact

We see that the disparate impact is about 0.53, which means the privileged group receives the favorable outcome at about twice the rate of the unprivileged group.

(In this case, the “favorable” outcome is label=1, i.e., high utilization.)

Train a model

We will train a logistic regression classifier. To do so, we have to import various functions from sklearn: a scaler, the logistic regression class, and make_pipeline.

PYTHON

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline # allows to stack modeling steps
from sklearn.pipeline import Pipeline # allow us to reference the Pipeline object type

PYTHON

dataset = dataset_orig_panel19_train # use the train dataset
model = make_pipeline(StandardScaler(), # scale the data to have mean 0 and variance 1
                      LogisticRegression(solver='liblinear', 
                                         random_state=1) # logistic regression model
                    )
fit_params = {'logisticregression__sample_weight': dataset.instance_weights} # use the instance weights to fit the model

lr_orig_panel19 = model.fit(dataset.features, dataset.labels.ravel(), **fit_params) # fit the model

Validate the model

We want to validate the model – that is, check that it has good accuracy and fairness when evaluated on the validation dataset. (By contrast, during training, we only optimize for accuracy and fairness on the training dataset.)

Recall that a logistic regression model can output probabilities (for our sklearn pipeline, via model.predict_proba), and we can choose our own threshold for predicting class 0 or 1. One goal of the validation process is to select this threshold for the model, i.e., the value v such that if the model’s output is greater than v, we predict the label 1.

The following function, test, computes the performance of a model at a variety of thresholds, given by thresh_arr, an array of threshold values generated with np.linspace. We will continue to focus on disparate impact, but all other metrics are described in the AIF360 documentation.

PYTHON

# Import the ClassificationMetric class to be able to compute metrics for the model
from aif360.metrics import ClassificationMetric

PYTHON

def test(dataset: MEPSDataset19, model:Pipeline, thresh_arr: np.ndarray) -> dict: 
    ''' 
        Given a dataset, model, and list of potential cutoff thresholds, compute various metrics
        for the model. Returns a dictionary of the metrics, including balanced accuracy, average odds
        difference, disparate impact, statistical parity difference, equal opportunity difference, and
        theil index.
    '''
    try:
        # sklearn classifier
        y_val_pred_prob = model.predict_proba(dataset.features) # get the predicted probabilities
        pos_ind = np.where(model.classes_ == dataset.favorable_label)[0][0] # index of the positive (favorable) class
    except AttributeError as e:
        print(e)
        # aif360 inprocessing algorithm
        y_val_pred_prob = model.predict(dataset).scores # get the predicted scores
        pos_ind = 0 # scores has a single column, so the positive class is at index 0

    metric_arrs = defaultdict(list) # create a dictionary to store the metrics
    
    # repeat the following for each potential cutoff threshold
    for thresh in thresh_arr:
        y_val_pred = (y_val_pred_prob[:, pos_ind] > thresh).astype(np.float64) # get the predicted labels

        dataset_pred = dataset.copy() # create a copy of the dataset
        dataset_pred.labels = y_val_pred # assign the predicted labels to the new dataset
        metric = ClassificationMetric( # create a ClassificationMetric object
                dataset, dataset_pred,
                unprivileged_groups=unprivileged_groups,
                privileged_groups=privileged_groups)

        # various metrics - can look up what they are on your own
        metric_arrs['bal_acc'].append((metric.true_positive_rate()
                                     + metric.true_negative_rate()) / 2) # balanced accuracy
        metric_arrs['avg_odds_diff'].append(metric.average_odds_difference()) # average odds difference
        metric_arrs['disp_imp'].append(metric.disparate_impact()) # disparate impact
        metric_arrs['stat_par_diff'].append(metric.statistical_parity_difference()) # statistical parity difference
        metric_arrs['eq_opp_diff'].append(metric.equal_opportunity_difference()) # equal opportunity difference
        metric_arrs['theil_ind'].append(metric.theil_index()) # theil index
    
    return metric_arrs

PYTHON

thresh_arr = np.linspace(0.01, 0.5, 50) # create an array of 50 potential cutoff thresholds ranging from 0.01 to 0.5
val_metrics = test(dataset=dataset_orig_panel19_val, 
                   model=lr_orig_panel19,
                   thresh_arr=thresh_arr) # call our function "test" with the validation data and lr model
lr_orig_best_ind = np.argmax(val_metrics['bal_acc']) # get the index of the best balanced accuracy

We will plot val_metrics. The x-axis will be the threshold we use to output the label 1 (i.e., if the raw score is larger than the threshold, we output 1).

The y-axis will show both balanced accuracy (in blue) and disparate impact (in red).

Note that we plot 1 - Disparate Impact, so now a score of 0 indicates no bias.

PYTHON

def plot(x:np.ndarray, x_name:str, y_left:np.ndarray, y_left_name:str, y_right:np.ndarray, y_right_name:str) -> None:
    '''
        Create a matplotlib plot with two y-axes and a single x-axis. 
    '''
    fig, ax1 = plt.subplots(figsize=(10,7)) # create a figure and axis
    ax1.plot(x, y_left) # plot the left y-axis data
    ax1.set_xlabel(x_name, fontsize=16, fontweight='bold') # set the x-axis label
    ax1.set_ylabel(y_left_name, color='b', fontsize=16, fontweight='bold')  # set the left y-axis label
    ax1.xaxis.set_tick_params(labelsize=14) # set the x-axis tick label size
    ax1.yaxis.set_tick_params(labelsize=14) # set the left y-axis tick label size
    ax1.set_ylim(0.5, 0.8) # set the left y-axis limits

    ax2 = ax1.twinx() # create a second y-axis that shares the same x-axis
    ax2.plot(x, y_right, color='r') # plot the right y-axis data
    ax2.set_ylabel(y_right_name, color='r', fontsize=16, fontweight='bold') # set the right y-axis label
    if 'DI' in y_right_name: 
        ax2.set_ylim(0., 0.7) # right y-axis limits for disparate-impact-based metrics (e.g., 1 - DI)
    else:
        ax2.set_ylim(-0.25, 0.1)  # right y-axis limits for difference-based metrics (e.g., average odds difference)

    best_ind = np.argmax(y_left) # get the index of the best balanced accuracy
    ax2.axvline(np.array(x)[best_ind], color='k', linestyle=':') # add a vertical line at the best balanced accuracy
    ax2.yaxis.set_tick_params(labelsize=14) # set the right y-axis tick label size
    ax2.grid(True) # add a grid

disp_imp = np.array(val_metrics['disp_imp']) # disparate impact (DI)
disp_imp_err = 1 - disp_imp # calculate 1 - DI
plot(thresh_arr, 'Classification Thresholds',
     val_metrics['bal_acc'], 'Balanced Accuracy',
     disp_imp_err, '1 - DI') # Plot balanced accuracy and 1-DI against the classification thresholds

Interpreting the plot

Answer the following questions:

  1. When the classification threshold is 0.1, what is the (approximate) accuracy and 1-DI score? What about when the classification threshold is 0.5?

  2. If you were developing the model, what classification threshold would you choose based on this graph? Why?

  1. Using a threshold of 0.1, the accuracy is about 0.71 and the 1-DI score is about 0.71. Using a threshold of 0.5, the accuracy is about 0.69 and the 1-DI score is about 0.79.

  2. The optimal accuracy occurs with a threshold of 0.19 (indicated by the dotted vertical line). However, the disparate impact is quite bad at this threshold. Choosing a slightly smaller threshold, e.g., around 0.15, yields similarly high accuracy and is slightly fairer. However, there’s no “good” outcome here: whenever the accuracy is near-optimal, the 1-DI score is high. If you were the model developer, you might want to consider interventions to improve the accuracy/fairness tradeoff, some of which we discuss below.

If you like, you can plot other metrics, e.g., average odds difference.

In the next cell, we write a function to print out a variety of other metrics. Instead of considering disparate impact directly, we will consider 1 - disparate impact. Recall that a disparate impact of 0 is very bad, and 1 is perfect – thus, considering 1 - disparate impact means that 0 is perfect and 1 is very bad, similar to the other metrics we consider. I.e., all of these metrics have a value of 0 if they are perfectly fair.

We print the values of several metrics here for illustrative purposes (i.e., to see that multiple metrics cannot all be optimized simultaneously). In practice, when evaluating a model it is typical to choose a single fairness metric based on the details of the situation. You can learn more about the various metrics in the AIF360 documentation.

PYTHON

def describe_metrics(metrics: dict, thresh_arr: np.ndarray) -> None:
    '''
        Given a dictionary of metrics and a list of potential cutoff thresholds, print the best
        threshold (based on 'bal_acc' balanced accuracy dictionary entry) and the corresponding
        values of other metrics at the selected threshold.
    '''
    best_ind = np.argmax(metrics['bal_acc']) # get the index of the best balanced accuracy
    print("Threshold corresponding to Best balanced accuracy: {:6.4f}".format(thresh_arr[best_ind]))
    print("Best balanced accuracy: {:6.4f}".format(metrics['bal_acc'][best_ind]))
    disp_imp_at_best_ind = 1 - metrics['disp_imp'][best_ind] # calculate 1 - DI at the best index
    print("\nCorresponding 1-DI value: {:6.4f}".format(disp_imp_at_best_ind))
    print("Corresponding average odds difference value: {:6.4f}".format(metrics['avg_odds_diff'][best_ind]))
    print("Corresponding statistical parity difference value: {:6.4f}".format(metrics['stat_par_diff'][best_ind]))
    print("Corresponding equal opportunity difference value: {:6.4f}".format(metrics['eq_opp_diff'][best_ind]))
    print("Corresponding Theil index value: {:6.4f}".format(metrics['theil_ind'][best_ind]))

describe_metrics(val_metrics, thresh_arr) # call the function

Test the model

Now that we have used the validation data to select the best threshold, we will evaluate the model on the test data.

PYTHON

lr_metrics = test(dataset=dataset_orig_panel19_test,
                       model=lr_orig_panel19,
                       thresh_arr=[thresh_arr[lr_orig_best_ind]]) # call our function "test" with the test data and lr model
describe_metrics(lr_metrics, [thresh_arr[lr_orig_best_ind]]) # print the metrics for the test data

Mitigate bias with in-processing


We will use reweighting as an in-processing step to try to increase fairness. AIF360 provides a Reweighing class that implements this; if you’re interested, you can read about how it works in the documentation.

If you look at the documentation, you will see that AIF360 classifies reweighing as a pre-processing, not an in-processing, intervention. Technically, AIF360’s implementation modifies the dataset, not the learning algorithm, so it is pre-processing. But because the weights are functionally equivalent to reweighting terms in the learning algorithm’s loss function, we treat it as an in-processing-style intervention in this lesson.

PYTHON

from aif360.algorithms.preprocessing import Reweighing

PYTHON

# Reweighing is an AIF360 class that reweights the data 
RW = Reweighing(unprivileged_groups=unprivileged_groups,
                privileged_groups=privileged_groups) # create a Reweighing object with the unprivileged and privileged groups
dataset_transf_panel19_train = RW.fit_transform(dataset_orig_panel19_train) # reweight the training data
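
Under the hood, reweighing assigns each example a weight of roughly P(group) * P(label) / P(group, label), so that group membership and label look statistically independent in the reweighted data. The cell below is a minimal sketch of that formula on toy arrays; it is illustrative only and not AIF360’s exact implementation, which also folds in the existing instance weights.

PYTHON

import numpy as np
import pandas as pd

def reweighing_weights(group: np.ndarray, label: np.ndarray) -> np.ndarray:
    '''
        Minimal sketch of the reweighing formula:
        w(g, y) = P(group = g) * P(label = y) / P(group = g, label = y).
        Illustrative only; AIF360's Reweighing also accounts for instance weights.
    '''
    df = pd.DataFrame({'group': group, 'label': label})
    p_group = df['group'].value_counts(normalize=True)        # P(group)
    p_label = df['label'].value_counts(normalize=True)        # P(label)
    p_joint = df.groupby(['group', 'label']).size() / len(df) # P(group, label)
    return np.array([p_group[g] * p_label[y] / p_joint[(g, y)]
                     for g, y in zip(df['group'], df['label'])])

# toy example: group 0 rarely gets label 1, so those examples are up-weighted
print(reweighing_weights(np.array([0, 0, 0, 1, 1, 1]),
                         np.array([0, 0, 1, 1, 1, 0])))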

We’ll also compute metrics for the reweighted data and print out its disparate impact.

PYTHON

metric_transf_panel19_train = BinaryLabelDatasetMetric(
        dataset_transf_panel19_train, # use train data
        unprivileged_groups=unprivileged_groups, # pass in unprivileged and privileged groups
        privileged_groups=privileged_groups)
explainer_transf_panel19_train = MetricTextExplainer(metric_transf_panel19_train) # create a MetricTextExplainer object

print(explainer_transf_panel19_train.disparate_impact()) # print disparate impact

Then, we’ll train a model, validate it, and evaluate it on the test data.

PYTHON

# train
dataset = dataset_transf_panel19_train  # use the reweighted training data
model = make_pipeline(StandardScaler(),
                      LogisticRegression(solver='liblinear', random_state=1)) # model pipeline
fit_params = {'logisticregression__sample_weight': dataset.instance_weights}
lr_transf_panel19 = model.fit(dataset.features, dataset.labels.ravel(), **fit_params) # fit the model

PYTHON

# validate
thresh_arr = np.linspace(0.01, 0.5, 50) # check 50 thresholds between 0.01 and 0.5
val_metrics = test(dataset=dataset_orig_panel19_val,
                   model=lr_transf_panel19,
                   thresh_arr=thresh_arr) # call our function "test" with the validation data and lr model
lr_transf_best_ind = np.argmax(val_metrics['bal_acc']) # get the index of the best balanced accuracy

PYTHON

# plot validation results
disp_imp = np.array(val_metrics['disp_imp']) # get the disparate impact values
disp_imp_err = 1 - np.minimum(disp_imp, 1/disp_imp) # calculate 1 - min(DI, 1/DI)
plot(thresh_arr, # use the classification thresholds as the x-axis
     'Classification Thresholds', 
     val_metrics['bal_acc'],  # plot accuracy on the first y-axis
     'Balanced Accuracy', 
     disp_imp_err, # plot 1 - min(DI, 1/DI) on the second y-axis
     '1 - min(DI, 1/DI)'
     )

PYTHON

describe_metrics(val_metrics, thresh_arr) # describe validation results

Test

PYTHON

lr_transf_metrics = test(dataset=dataset_orig_panel19_test,
                         model=lr_transf_panel19,
                         thresh_arr=[thresh_arr[lr_transf_best_ind]]) # call our function "test" with the test data and lr model
describe_metrics(lr_transf_metrics, [thresh_arr[lr_transf_best_ind]]) # describe test results

We see that the disparate impact score on the test data is better after reweighting than it was originally.

How do the other fairness metrics compare?

Mitigate bias with post-processing


We will use a method, ThresholdOptimizer, that is implemented in the library Fairlearn. ThresholdOptimizer finds custom thresholds for each demographic group so as to achieve parity in the desired group fairness metric.

We will focus on demographic parity, but feel free to try other fairness constraints if you’re curious how they perform.

The first step is creating the ThresholdOptimizer object. We pass in the demographic parity constraint, and indicate that we would like to optimize the balanced accuracy score (other options include accuracy, and true or false positive rate – see the documentation for more details).

PYTHON

from fairlearn.postprocessing import ThresholdOptimizer # import ThresholdOptimizer

PYTHON

# create a ThresholdOptimizer object
to = ThresholdOptimizer(estimator=model, 
                        constraints="demographic_parity", # set the constraint to demographic parity
                        objective="balanced_accuracy_score", # optimize for balanced accuracy
                        prefit=True) 

Next, we fit the ThresholdOptimizer object to the validation data.

PYTHON

to.fit(dataset_orig_panel19_val.features, dataset_orig_panel19_val.labels, 
       sensitive_features=dataset_orig_panel19_val.protected_attributes[:,0]) # fit the ThresholdOptimizer object

Then, we’ll create a helper function, mini_test, so that we can still call the describe_metrics function even though we are no longer evaluating our method at a variety of thresholds.

After that, we call the ThresholdOptimizer’s predict function on the validation and test data, and then compute metrics and print the results.

PYTHON

def mini_test(dataset:MEPSDataset19, preds:np.ndarray) -> dict:
    '''
        Given a dataset and predictions, compute various metrics for the model. Returns a dictionary of the metrics,
        including balanced accuracy, average odds difference, disparate impact, statistical parity difference, equal
        opportunity difference, and theil index.
    '''
    metric_arrs = defaultdict(list)
    dataset_pred = dataset.copy()
    dataset_pred.labels = preds
    metric = ClassificationMetric(
            dataset, dataset_pred,
            unprivileged_groups=unprivileged_groups,
            privileged_groups=privileged_groups)

    # various metrics - can look up what they are on your own
    metric_arrs['bal_acc'].append((metric.true_positive_rate()
                                    + metric.true_negative_rate()) / 2)
    metric_arrs['avg_odds_diff'].append(metric.average_odds_difference())
    metric_arrs['disp_imp'].append(metric.disparate_impact())
    metric_arrs['stat_par_diff'].append(metric.statistical_parity_difference())
    metric_arrs['eq_opp_diff'].append(metric.equal_opportunity_difference())
    metric_arrs['theil_ind'].append(metric.theil_index())
    
    return metric_arrs

PYTHON

# get predictions for validation dataset using the ThresholdOptimizer
to_val_preds = to.predict(dataset_orig_panel19_val.features, 
                          sensitive_features=dataset_orig_panel19_val.protected_attributes[:,0])

# get predictions for test dataset using the ThresholdOptimizer
to_test_preds = to.predict(dataset_orig_panel19_test.features, 
                           sensitive_features=dataset_orig_panel19_test.protected_attributes[:,0])

PYTHON

to_val_metrics = mini_test(dataset_orig_panel19_val, to_val_preds) # compute metrics for the validation set
to_test_metrics = mini_test(dataset_orig_panel19_test, to_test_preds) # compute metrics for the test set

PYTHON

print("Remember, `Threshold corresponding to Best balanced accuracy` is just a placeholder here.")
describe_metrics(to_val_metrics, [0]) # check accuracy (ignore other metrics for now)

PYTHON

print("Remember, `Threshold corresponding to Best balanced accuracy` is just a placeholder here.")

describe_metrics(to_test_metrics, [0]) # check accuracy (ignore other metrics for now)

Scroll up and see how these results compare with the original classifier and with the in-processing technique.

A major difference is that the accuracy is now lower. In practice, it might be better to use an algorithm that allows a custom tradeoff between the accuracy sacrificed and the fairness gained.

We can also see what threshold is being used for each demographic group by examining the interpolated_thresholder_.interpolation_dict attribute of the ThresholdOptimizer.

PYTHON

threshold_rules_by_group = to.interpolated_thresholder_.interpolation_dict # get the threshold rules by group
threshold_rules_by_group # print the threshold rules by group

Recall that a value of 1 in the Race column corresponds to White people, while a value of 0 corresponds to non-White people.

Due to the inherent randomness of the ThresholdOptimizer, you might get slightly different results than your neighbors. When we ran the previous cell, the output was

{0.0: {'p0': 0.9287205987170348, 'operation0': [>0.5], 'p1': 0.07127940128296517, 'operation1': [>-inf]}, 1.0: {'p0': 0.002549618320610717, 'operation0': [>inf], 'p1': 0.9974503816793893, 'operation1': [>0.5]}}

This tells us that for non-White individuals:

  • If the score is above 0.5, predict 1.

  • Otherwise, predict 1 with probability 0.071.

And for White individuals:

  • If the score is above 0.5, predict 1 with probability 0.997.

  • Otherwise, predict 0.
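
To make the randomization concrete, here is a small sketch of how the rule for the non-White group (group 0.0) would be applied. The probability and the 0.5 threshold are hard-coded from the example output above, so treat them purely as an illustration; your fitted ThresholdOptimizer will produce different values.

PYTHON

import numpy as np

rng = np.random.default_rng(0)

def randomized_predict_group0(score: float) -> int:
    '''
        Sketch of the randomized rule above for group 0.0 (non-White):
        with probability ~0.929, predict 1 only if the score exceeds 0.5;
        with probability ~0.071, predict 1 regardless of the score (threshold of -inf).
        Values are hard-coded from the example output and are illustrative only.
    '''
    if rng.random() < 0.9287:
        return int(score > 0.5)
    return 1 # threshold of -inf: always predict the favorable label

# a borderline score of 0.4 is usually predicted 0, but occasionally 1
print([randomized_predict_group0(0.4) for _ in range(20)])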

Discuss

What are the pros and cons of improving the model fairness by introducing randomization?

Pros: Randomization can be effective at increasing fairness.

Cons: There is less predictability and explainability in model outcomes. Even though model outputs are fair in aggregate according to a defined group fairness metric, decisions may feel unfair on an individual basis because similar individuals (or even the same individual, at different times) may be treated unequally. Randomization may not be appropriate in settings (e.g., medical diagnosis) where accuracy is paramount.

Key Points

  • It’s important to consider many dimensions of model performance: a single accuracy score is not sufficient.
  • There is no single definition of “fair machine learning”: different notions of fairness are appropriate in different contexts.
  • Representational harms and stereotypes can be perpetuated by generative AI.
  • The fairness of a model can be improved by using techniques like data reweighting and model postprocessing.

Content from Interpretability versus explainability


Last updated on 2024-11-13 | Edit this page

Estimated time: 2 minutes

Overview

Questions

  • What are model interpretability and model explainability? Why are they important?
  • How do you choose between interpretable models and explainable models in different contexts?

Objectives

  • Understand and distinguish between explainable machine learning models and interpretable machine learning models.
  • Make informed model selection choices based on the goals of your model.

Key Points

  • Model Explainability vs. Model Interpretability:
    • Interpretability: The degree to which a human can understand the cause of a decision made by a model, crucial for verifying correctness and ensuring compliance.
    • Explainability: The extent to which the internal mechanics of a machine learning model can be articulated in human terms, important for transparency and building trust.
  • Choosing Between Explainable and Interpretable Models:
    • When Transparency is Critical: Use interpretable models when understanding how decisions are made is essential.
    • When Performance is a Priority: Use explainable models when accuracy is more important, leveraging techniques like LIME and SHAP to clarify complex models.
  • Accuracy vs. Complexity:
    • The relationship between model complexity and accuracy is not always linear. Increasing complexity can improve accuracy up to a point but may lead to overfitting, highlighting the gray area in model selection. This is illustrated by the accuracy vs. complexity plot, which shows different models on these axes.

Introduction

In this lesson, we will explore the concepts of interpretability and explainability in machine learning models. For applied scientists, choosing the right model for your research data is critical. Whether you’re working with patient data, environmental factors, or financial information, understanding how a model arrives at its predictions can significantly impact your work.

Interpretability

In the context of machine learning, interpretability is the degree to which a human can understand the cause of a decision made by a model, crucial for verifying correctness and ensuring compliance.

“Interpretable” models: Generally refers to models that are inherently understandable, such as…

  • Linear regression: Examining the coefficients along with confidence intervals (CIs) helps understand the strength and direction of the relationship between features and predictions.
  • Decision trees: Visualizing decision trees allows users to see the rules that lead to specific predictions, clarifying how features interact in the decision-making process.
  • Rule-based classifiers: these models provide clear insights into how input features influence predictions, making it easier for users to verify and trust the outcomes.
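
As a quick illustration of what “inherently understandable” looks like in practice, the sketch below fits a linear model and a shallow decision tree to a toy dataset (the breast cancer data bundled with scikit-learn) and inspects them directly. The dataset and model choices are placeholders for illustration.

PYTHON

# Sketch: inspect two inherently interpretable models on a toy dataset.
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# Linear model: standardized coefficients show the direction and relative
# strength of each feature's association with the prediction.
linear = LogisticRegression(max_iter=1000).fit(StandardScaler().fit_transform(X), y)
print(dict(zip(X.columns, linear.coef_[0].round(2))))

# Shallow decision tree: the learned rules can be printed and read directly.
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
print(export_text(tree, feature_names=list(X.columns)))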

However, as we scale these models up (e.g., high-dimensional regression models or random forests), their complexity can increase significantly, potentially making them much less interpretable than their simpler counterparts.

Explainability

Explainability is the extent to which the internal mechanics of a machine learning model can be articulated in human terms; it is important for transparency and building trust.

Explainable models: Typically refers to more complex models, such as neural networks or ensemble methods, that may act as black boxes. While these models can deliver high accuracy, they require additional techniques (like LIME and SHAP) to explain their decisions.

Explainability methods preview: Various explainability methods exist to help clarify how complex models work. For instance…

  • LIME (Local Interpretable Model-agnostic Explanations) provides insights into individual predictions by approximating the model locally with a simpler, interpretable model.
  • SHAP (SHapley Additive exPlanations) assigns each feature an importance value for a particular prediction, helping understand the contribution of each feature.
  • Saliency Maps visually highlight which parts of an input (e.g., in images) are most influential for a model’s prediction.

These techniques, which we’ll talk more about in a later episode, bridge the gap between complex models and user understanding, enhancing transparency while still leveraging powerful algorithms.

Accuracy vs. Complexity

The traditional idea that simple models (e.g., regression, decision trees) are inherently interpretable and complex models (neural nets) are truly black-box is increasingly inadequate. Modern interpretable models, such as high-dimensional regression or tree-based methods with hundreds of variables, can be as difficult to understand as neural networks. This leads to a more fluid spectrum of complexity versus accuracy.

The accuracy vs. complexity plot from the AAAI tutorial helps to visualize the continuous relationship between model complexity, accuracy, and interpretability. It showcases that the trade-off is not always straightforward, and some models can achieve a balance between interpretability and strong performance.

This evolving landscape demonstrates that the old clusters of “interpretable” versus “black-box” models break down. Instead, we must evaluate models across the dimensions of complexity and accuracy.

Understanding the trade-off between model complexity and accuracy is crucial for effective model selection. As model complexity increases, accuracy typically improves. However, more complicated models become more difficult to interpret and explain.

Accuracy vs. Complexity Plot

Discussion of the Plot:

  • X-Axis: Represents model complexity, ranging from simple models (like linear regression) to complex models (like deep neural networks).
  • Y-Axis: Represents accuracy, demonstrating how well each model performs on a given task.

This plot illustrates that while simpler models offer clarity and ease of understanding, they may not effectively capture complex relationships in the data. Conversely, while complex models can achieve higher accuracy, they may sacrifice interpretability, which can hinder trust in their predictions.

Exploring Model Choices

We will analyze a few real-world scenarios and discuss the trade-offs between “interpretable models” (e.g., regression, decision trees, etc.) and “explainable models” (e.g., neural nets).

For each scenario, you’ll consider key factors like accuracy, complexity, and transparency, and answer discussion questions to evaluate the strengths and limitations of each approach.

Here are some of the questions you’ll reflect on during the exercises:

  • What are the advantages of using interpretable models versus explainable (black box) models in the given context?
  • What are the potential drawbacks of each approach?
  • How might the specific goals of the task influence your choice of model?
  • Are there situations where high accuracy justifies the use of less interpretable models?

As you work through these exercises, keep in mind the broader implications of these decisions, especially in fields like healthcare, where model transparency can directly impact trust and outcomes.

Exercise 1: Model Selection for Predicting COVID-19 Progression, a study by Giotta et al.

Scenario:
In the early days of the COVID-19 pandemic, healthcare professionals faced unprecedented challenges in predicting which patients were at higher risk of severe outcomes. Accurate predictions of death or the need for intensive care could guide resource allocation and improve patient care. A study explored the use of various biomarkers to build predictive models, highlighting the importance of both accuracy and transparency in such high-stakes settings.

Objective:
Predict severe outcomes (death or transfer to intensive care) in COVID-19 patients using biomarkers.

Dataset features:
The dataset includes biomarkers from three categories:
- Hematological markers: White blood cell count, neutrophils, lymphocytes, platelets, hemoglobin, etc.
- Biochemical markers: Albumin, bilirubin, creatinine, cardiac troponin, LDH, etc.
- Inflammatory markers: CRP, serum ferritin, interleukins, TNFα, etc.

These features are critical for understanding disease progression and predicting outcomes.

Discussion questions

Compare the advantages

  • What are the advantages of using interpretable models such as decision trees in predicting COVID-19 outcomes?
  • What are the advantages of using black box models such as neural networks in this scenario?

Assess the drawbacks

  • What are the potential drawbacks of using interpretable models like decision trees?
  • What are the potential drawbacks of using black box models in healthcare settings?

Decision-making criteria

  • In what situations might you prioritize an interpretable model over a black box model, and why?
  • Are there scenarios where the higher accuracy of black box models justifies their use despite their lack of transparency?

Compare the advantages

  • Interpretable models: Allow healthcare professionals to understand and trust the model’s decisions, providing clear insights into which biomarkers contribute most to predicting bad outcomes. This transparency is crucial in critical fields such as healthcare, where understanding the decision-making process can inform treatment plans and improve patient outcomes.
  • Black box models: Often provide higher predictive accuracy, which can be crucial for identifying patterns in complex datasets. They can capture non-linear relationships and interactions that simpler models might miss.

Assess the drawbacks

  • Interpretable models: May not capture complex relationships in the data as effectively as black box models, potentially leading to lower predictive accuracy in some cases.
  • Black box models: Can be difficult to interpret, which hinders trust and adoption by medical professionals. Without understanding the model’s reasoning, it becomes challenging to validate its correctness, ensure regulatory compliance, and effectively debug or refine the model.

Decision-making criteria

  • Interpretable models: When transparency, trust, and regulatory compliance are critical, such as in healthcare settings where understanding and validating decisions is essential.
  • Black box models: When the need for high predictive accuracy outweighs the need for transparency, and when supplementary methods for interpreting the model’s output can be employed.

Exercise 2: COVID-19 Diagnosis Using Chest X-Rays, a study by Ucar and Korkmaz

Objective: Diagnose COVID-19 through chest X-rays.

Motivation:

The COVID-19 pandemic has had an unprecedented impact on global health, affecting millions of people worldwide. One of the critical challenges in managing this pandemic is the rapid and accurate diagnosis of infected individuals. Traditional methods, such as the Reverse Transcription Polymerase Chain Reaction (RT-PCR) test, although widely used, have several drawbacks. These tests are time-consuming, require specialized equipment and personnel, and often suffer from low detection rates, necessitating multiple tests to confirm a diagnosis.

In this context, radiological imaging, particularly chest X-rays, has emerged as a valuable tool for COVID-19 diagnosis. Early studies have shown that COVID-19 causes specific abnormalities in chest X-rays, such as ground-glass opacities, which can be used as indicators of the disease. However, interpreting these images requires expertise and time, both of which are in short supply during a pandemic.

To address these challenges, researchers have turned to machine learning techniques…

Dataset Specification: Chest X-ray images

Real-World Impact:

The COVID-19 pandemic highlighted the urgent need for rapid and accurate diagnostic tools. Traditional methods like RT-PCR tests, while effective, are often time-consuming and have variable detection rates. Using chest X-rays for diagnosis offers a quicker and more accessible alternative. By analyzing chest X-rays, healthcare providers can swiftly identify COVID-19 cases, enabling timely treatment and isolation measures. Developing a machine learning method that can quickly and accurately analyze chest X-rays can significantly enhance the speed and efficiency of the healthcare response, especially in areas with limited access to RT-PCR testing.

Discussion Questions:

  1. Compare the Advantages:
    • What are the advantages of using deep neural networks in diagnosing COVID-19 from chest X-rays?
    • What are the advantages of traditional methods, such as genomic data analysis, for COVID-19 diagnosis?
  2. Assess the Drawbacks:
    • What are the potential drawbacks of using deep neural networks for COVID-19 diagnosis from chest X-rays?
    • How do these drawbacks compare to those of traditional methods?
  3. Decision-Making Criteria:
    • In what situations might you prioritize using deep neural networks over traditional methods, and why?
    • Are there scenarios where the rapid availability of X-ray results justifies the use of deep neural networks despite potential drawbacks?
  1. Compare the Advantages:
    • Deep Neural Networks: Provide high accuracy (e.g., 98%) in diagnosing COVID-19 from chest X-rays, offering a quick and non-invasive diagnostic tool. They can handle large amounts of image data and identify complex patterns that might be missed by human eyes.
    • Traditional Methods: Provide detailed and specific diagnostic information by analyzing genomic data and biomarkers, which can be crucial for understanding the virus’s behavior and patient response.
  2. Assess the Drawbacks:
    • Deep Neural Networks: Require large labeled datasets for training, which may not always be available. The models can be seen as “black boxes”, making it challenging to interpret their decisions without additional explainability methods.
    • Traditional Methods: Time-consuming and may have lower detection accuracy. They often require specialized equipment and personnel, leading to delays in diagnosis.
  3. Decision-Making Criteria:
    • Prioritizing Deep Neural Networks: When rapid diagnosis is critical, and chest X-rays are readily available. Useful in large-scale screening scenarios where speed is more critical than the detailed understanding provided by genomic data.
    • Using Traditional Methods: When detailed and specific information about the virus is needed for treatment planning, and when the availability of genomic data and biomarkers is not a bottleneck.

Content from Explainability methods overview


Last updated on 2024-11-13 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • What are the major categories of explainability methods, and how do they differ?
  • How do you determine which explainability method to use for a specific use case?
  • What are the trade-offs between black-box and white-box approaches to explainability?
  • How do post-hoc explanation methods compare to inherently interpretable models in terms of utility and reliability?

Objectives

  • Understand the key differences between black-box and white-box explanation methods.
  • Explore the trade-offs between post-hoc explainability and inherent interpretability in models.
  • Identify and categorize different explainability techniques based on their scope, model access, and approach.
  • Learn when to apply specific explainability techniques for various machine learning tasks.

Fantastic Explainability Methods and Where to Use Them


We will now take a bird’s-eye view of explainability methods that are widely applied on complex models like neural networks. We will get a sense of when to use which kind of method, and what the tradeoffs between these methods are.

Three axes of use cases for understanding model behavior


When deciding which explainability method to use, it is helpful to define your setting along three axes. This helps in understanding the context in which the model is being used, and the kind of insights you are looking to gain from the model.

Inherently Interpretable vs Post Hoc Explainable

Understanding the tradeoff between interpretability and complexity is crucial in machine learning. Simple models like decision trees, random forests, and linear regression offer transparency and ease of understanding, making them ideal for explaining predictions to stakeholders. In contrast, neural networks, while powerful, lack interpretability due to their complexity. Post hoc explainable techniques can be applied to neural networks to provide explanations for predictions, but it’s essential to recognize that using such methods involves a tradeoff between model complexity and interpretability.

Striking the right balance between these factors is key to selecting the most suitable model for a given task, considering both its predictive performance and the need for interpretability.

_Credits: AAAI 2021 Tutorial on Explaining Machine Learning Predictions: State of the Art, Challenges, Opportunities._
The tradeoff between Interpretability and Complexity

Local vs Global Explanations

Local explanations focus on describing model behavior within a specific neighborhood, providing insights into individual predictions. Conversely, global explanations aim to elucidate overall model behavior, offering a broader perspective. While global explanations may be more comprehensive, they run the risk of being overly complex.

Both types of explanations are valuable for uncovering biases and ensuring that the model makes predictions for the right reasons. The tradeoff between local and global explanations has a long history in statistics, with methods like linear regression (global) and kernel smoothing (local) illustrating the importance of considering both perspectives in statistical analysis.

Black box vs White Box Approaches

Techniques that require access to model internals (e.g., the model architecture and weights) are called “white box”, while techniques that only need query access to the model are called “black box”. Even without access to the model weights, black box (or top-down) approaches can shed a lot of light on model behavior. For example, by simply evaluating the model on certain kinds of data, high-level biases or trends in the model’s decision-making process can be unearthed.

White box approaches use the weights and activations of the model to understand its behavior. These classes of methods are more complex and diverse, and we will discuss them in more detail later in this episode. Some large models are closed-source due to commercial or safety concerns; for example, users cannot access the weights of GPT-4. This limits the use of white box explanations for such models.

Classes of Explainability Methods for Understanding Model Behavior


Diagnostic Testing

This is the simplest approach to explaining model behavior: apply a series of unit tests to the model, where each test is a sample input with a known correct output. By identifying test examples that break the heuristics the model relies on (called counterfactuals), you can gain insight into the model’s high-level behavior.

Example Methods: Counterfactuals, Unit tests

Pros and Cons: These methods give insight into the high-level behavior of the model without needing access to the model weights. This is especially useful with recent powerful closed-source models like GPT-4. One challenge with this approach is that it is hard to identify in advance which heuristics a model may depend on.
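
As a toy illustration of this style of testing, the sketch below trains a tiny bag-of-words sentiment classifier and then checks it with an invariance pair and a counterfactual pair. The data and test sentences are invented purely for illustration.

PYTHON

# Toy sketch of diagnostic / behavioral testing: train a tiny sentiment model,
# then check it with hand-written test cases. All data here is invented.
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

train_texts = ["the burger was great", "loved the fries", "amazing service",
               "the burger was terrible", "hated the fries", "awful service"]
train_labels = [1, 1, 1, 0, 0, 0]  # 1 = positive, 0 = negative
clf = make_pipeline(CountVectorizer(), LogisticRegression()).fit(train_texts, train_labels)

# invariance test: swapping the food item should NOT change the prediction
print(clf.predict(["the burger was great", "the salad was great"]))
# counterfactual test: flipping the sentiment word SHOULD change the prediction
print(clf.predict(["the service was amazing", "the service was awful"]))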

Baking interpretability into models

Some recent research has focused on tweaking highly complex models like neural networks to make them more inherently interpretable. One such example with language models involves training the model to generate a rationale alongside its prediction. This approach has gained some traction, and there are even public benchmarks for evaluating the quality of these generated rationales.

Example methods: Rationales with WT5, Older approaches for rationales

Pros and cons: These models hope to achieve the best of both worlds: complex models that are also inherently interpretable. However, research in this direction is still new, and there are no established and reliable approaches for real world applications just yet.

Identifying Decision Rules of the Model

In this class of methods, we try to find a set of rules that broadly explain the decision-making process of the model. Loosely, these rules are of the form “if a specific condition is met, then the model will predict a certain class”.

Example methods: Anchors, Universal Adversarial Triggers

Table caption: "Generated anchors for Tabular datasets". Table shows the following rules: for the adult dataset, predict less than 50K if no capital gain or loss and never married. Predict over 50K if country is US, married, and work hours over 45. For RCDV dataset, predict not rearrested if person has no priors, no prison violations, and crime not against property. Predict re-arrested if person is male, black, has 1-5 priors, is not married, and the crime not against property. For the Lending dataset, predict bad loan if FICO score is less than 650. Predict good loan if FICO score is between 650 and 700 and loan amount is between 5400 and 10000.
Example use of anchors (table from Ribeiro et al.)

Pros and cons: Such global rules can help find “bugs” in the model or identify high-level biases. But finding rules with broad coverage is challenging. Furthermore, these rules mostly showcase the model’s weaknesses and give next to no insight into why those weaknesses exist.
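
The dedicated algorithms linked above are the real tools here, but a simpler stand-in conveys the idea: fit a shallow “surrogate” decision tree to a black-box model’s predictions and read off approximate global rules. The sketch below uses synthetic data and is not the Anchors algorithm itself.

PYTHON

# Sketch of rule extraction via a surrogate model (a simpler stand-in for
# methods like Anchors): fit a shallow decision tree to a black-box model's
# predictions and print the rules it learns. Toy data for illustration only.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = make_classification(n_samples=1000, n_features=5, random_state=0)
black_box = RandomForestClassifier(random_state=0).fit(X, y)

surrogate = DecisionTreeClassifier(max_depth=2, random_state=0)
surrogate.fit(X, black_box.predict(X))  # mimic the black box, not the true labels
print(export_text(surrogate, feature_names=[f"x{i}" for i in range(5)]))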

Visualizing model weights or representations

Just as a picture is worth a thousand words, visualizations can encapsulate complex model behavior in a single image. Visualizations are commonly used to explain neural networks, where the weights or data representations of the model are directly visualized. Many such approaches involve reducing the high-dimensional weights or representations to a 2D or 3D space, using techniques like PCA, t-SNE, or UMAP. Alternatively, these visualizations can retain the high-dimensional representation, but use color or size to identify which dimensions or neurons are most important.

Example methods: Visualizing attention heatmaps, Weight visualizations, Model activation visualizations

Image shows a grid with 3 rows and 50 columns. Each cell is colored on a scale of -1.5 (white) to 0.9 (dark blue). Darker colors are concentrated in the first row in seemingly-random columns.
Example usage of visualizing attention heatmaps for a part-of-speech (POS) identification task using word2vec-encoded vectors. Each cell is a unit in a neural network (each row is a layer and each column is a dimension). Darker colors indicate that a unit is more important for predictive accuracy (figure from Li et al.).

Pros and cons: Gleaning model behavior from visualizations is intuitive and user-friendly, and visualizations sometimes come with interactive interfaces. However, visualizations can be misleading, especially when high-dimensional vectors are reduced to 2D, leading to a loss of information (the crowding problem).

An iconic debate over the validity of visualizations has centered on attention heatmaps. Research has shown them to be unreliable, and then reliable again. (Check out the titles of these papers!) Thus, visualization should only be used as an additional step in an analysis, and not as a standalone method.
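
As a small sketch of the dimensionality-reduction flavor of these visualizations, the code below projects some 64-dimensional “representations” down to 2D with PCA; random vectors stand in for the hidden activations you would extract from a real model.

PYTHON

# Minimal sketch of visualizing high-dimensional representations in 2D with PCA.
# Random vectors stand in for a model's hidden activations; with a real model
# you would pass in activations extracted from a chosen layer.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
reps = np.vstack([rng.normal(0, 1, size=(100, 64)),   # pretend class-0 activations
                  rng.normal(3, 1, size=(100, 64))])  # pretend class-1 activations
labels = np.array([0] * 100 + [1] * 100)

coords = PCA(n_components=2).fit_transform(reps)      # 64-D -> 2-D
plt.scatter(coords[:, 0], coords[:, 1], c=labels, cmap='coolwarm', s=10)
plt.xlabel('PC 1'); plt.ylabel('PC 2'); plt.title('Representations projected with PCA')
plt.show()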

Understanding the impact of training examples

These techniques unearth which training data instances caused the model to generate a specific prediction for a given sample. At a high level, they mathematically identify which training samples – if removed from the training process – would most change a particular prediction.

Example methods: Influence functions, Representer point selection

Two images. On the left, several antelope are standing in the background on a grassy field. On the right, several zebra graze in a field in the background, while there is one antelope in the foreground and other antelope in the background.
Example usage of representer point selection. The image on the left is a test image that is misclassified as a deer (the true label is antelope). The image on the right is the most influential training point. We see that this image is labeled “zebra,” but contains both zebras and antelopes. (Example adapted from Yeh et al.)

Pros and cons: The insights from these approaches are actionable: by identifying the data responsible for a prediction, they can help correct labels or annotation artifacts in that data. Unfortunately, these methods scale poorly with the size of the model and training data, quickly becoming computationally expensive. Furthermore, even when we know which data points had a high influence on a prediction, we do not know what it was about those data points that caused the influence.
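
As a brute-force sketch of the underlying idea (not an influence-function implementation), the code below retrains a small model with each training point left out in turn and records how much a single test prediction moves. Influence functions approximate this quantity without retraining.

PYTHON

# Brute-force sketch of "influence": drop one training point, retrain, and see
# how much a test prediction changes. Influence functions approximate this
# without retraining; here we just do it directly on a tiny synthetic problem.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=201, n_features=5, random_state=0)
X_train, y_train, x_test = X[:200], y[:200], X[200:]

base_prob = LogisticRegression().fit(X_train, y_train).predict_proba(x_test)[0, 1]

influence = []
for i in range(len(X_train)):
    mask = np.arange(len(X_train)) != i  # leave out training point i
    prob = LogisticRegression().fit(X_train[mask], y_train[mask]).predict_proba(x_test)[0, 1]
    influence.append(abs(prob - base_prob))

print("Most influential training point:", int(np.argmax(influence)))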

Understanding the impact of a single example

For a single input, what parts of the input were most important in generating the model’s prediction? These methods study the signal sent by various features to the model, and observe how the model reacts to changes in these features.

Example methods: Saliency Maps, LIME/SHAP, Perturbations (Input reduction, Adversarial Perturbations)

These methods can be further subdivided into two categories: gradient-based methods that rely on white-box model access to directly see the impact of changing a single input, and perturbation-based methods that manually perturb an input and re-query the model to see how the prediction changes.

Two rows images (5 images per row). Leftmost column shows two different pictures, each containing a cat and a dog. Remaining columns show the saliency maps using different techniques (VanillaGrad, InteGrad, GuidedBackProp, and SmoothGrad). Each saliency map has red dots (indicated regions that are influential for predicting "dog") and blue dots (influential for predicting "cat"). All methods except GuidedBackProp have good overlap between the respective dots and where the animals appear in the image. SmoothGrad has the most precise mapping.
Example saliency maps. The right 4 columns show the result of different saliency method techniques, where red dots indicate regions that are influential for predicting “dog” and blue dots indicate regions that are influential for predicting “cat”. The image creators argue that their method, SmoothGrad, is most effective at mapping model behavior to images. (Image taken from Smilkov et al.)

Pros and cons: These methods are fast to compute and flexible across models. However, the insights they provide are not always actionable: knowing which part of the input caused the prediction does not explain why that part caused it. When these methods surface an issue, it is also hard to tell whether the problem lies in the model itself or only in the specific inputs tested. Relatedly, these methods can be unstable, and can even be fooled by adversarial examples.

Probing internal representations

As the name suggests, this class of methods aims to probe the internals of a model, to discover what kind of information or knowledge is stored inside the model. Probes are often administered to a specific component of the model, like a set of neurons or layers within a neural network.

Example methods: Probing classifiers, Causal tracing

The phrase "The nurse examined the farmer for injuries because PRONOUN" is shown twice, once with PRONOUN=she and once with PRONOUN=he. Each word is annotated with the importance of three different attention heads. The distribution of which heads are important with each pronoun differs for all words, but especially for nurse and farmer.
Example probe output. The image shows the result from probing three attention heads. We see that gender stereotypes are encoded into the model because the heads that are important for nurse and farmer change depending on the final pronoun. Specifically, Head 5-10 attends to the stereotypical gender assignment while Head 4-6 attends to the anti-stereotypical gender assignment. (Image taken from Vig et al.)

Pros and cons: Probes have shown that it is possible to find highly interpretable components in a complex model; e.g., MLP layers in transformers have been shown to store factual knowledge in a structured manner. However, there is no systematic way of finding interpretable components, and many components may remain difficult for humans to understand. Furthermore, the model components that have been shown to contain certain knowledge may not actually play a role in the model’s prediction.

Is that all?

Nope! We’ve discussed a few of the common explanation techniques, but many others exist. In particular, specialized model architectures often need their own explanation algorithms. For instance, Yuan et al. give an overview of different explanation techniques for graph neural networks (GNNs).

Classifying explanation techniques

For each of the explanation techniques described above, discuss the following with a partner:

  • Does it require black-box or white-box model access?
  • Are the explanations it provides global or local?
  • Is the technique post-hoc or does it rely on inherent interpretability of the model?
| Approach | Post Hoc or Inherently Interpretable? | Local or Global? | White Box or Black Box? |
| --- | --- | --- | --- |
| Diagnostic Testing | Post Hoc | Global | Black Box |
| Baking interpretability into models | Inherently Interpretable | Local | White Box |
| Identifying Decision Rules of the Model | Post Hoc | Both | White Box |
| Visualizing model weights or representations | Post Hoc | Global | White Box |
| Understanding the impact of training examples | Post Hoc | Local | White Box |
| Understanding the impact of a single example | Post Hoc | Local | Both |
| Probing internal representations of a model | Post Hoc | Global/Local | White Box |

What explanation should you use when? There is no simple answer, as it depends on your goals (i.e., why you need an explanation), who the audience is, the model architecture, and the availability of model internals (e.g., there is no white-box access to ChatGPT unless you work for OpenAI!). The next exercise asks you to consider different scenarios and discuss which explanation techniques are appropriate.

Challenge

Think about the following scenarios and suggest which explainability method would be most appropriate to use, and what information could be gained from that method. Furthermore, think about the limitations of your findings.

Note: These are open-ended questions, and there is no correct answer. Feel free to break into discussion groups to discuss the scenarios.

Scenario 1: Suppose that you are an ML engineer working at a tech company. A fast-food chain consults with you about sentiment analysis based on feedback they collected on Yelp and in their own survey. You use an open-source LLM such as Llama-2 and fine-tune it on the review text. The fast-food company asks you to provide explanations for the model: Are there any outlier reviews? How does each review in the data affect the fine-tuned model? Which parts of the language in a review indicate that a customer likes or dislikes the food? Can you score the food quality according to the reviews? Do the reviews show a trend over time? Which items are gaining or losing popularity? Q: Can you suggest a few explainability methods that may be useful for answering these questions?

Scenario 2: Suppose that you are a radiologist who analyzes medical images of patients with the help of machine learning models. You use black-box models (e.g., CNNs, Vision Transformers) to complement human expertise and get useful information before making high-stakes decisions. Which areas of a medical image most likely explain the output of a black-box model? Can we visualize and understand what features are captured by the intermediate components of black-box models? How do we know if there is a distribution shift? How can we tell if an image is an out-of-distribution example? Q: Can you suggest a few explainability methods that may be useful for answering these questions?

Scenario 3: Suppose that you work in genomics and you have just collected samples of single-cell data into a table: each row records gene expression levels, and each column represents a single cell. You are interested in scientific hypotheses about the evolution of cells, and you believe that only a few genes play a role in your study. What exploratory data analysis techniques would you use to examine the dataset? How do you check whether there are potential outliers or irregularities in the dataset? What can you do to find the set of most explanatory genes? How do you know if there is clustering, and whether there is a trajectory of changes in the cells? Q: Can you explain the decisions you make for each method you use?

Summary


There are many available explanation techniques and they differ along three dimensions: model access (white-box or black-box), explanation scope (global or local), and approach (inherently interpretable or post-hoc). There’s often no objectively-right answer of which explanation technique to use in a given situation, as the different methods have different tradeoffs.

Content from Explainability methods: deep dive


Last updated on 2024-11-15 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • How can we identify which parts of an input contribute most to a model’s prediction?
  • What insights can saliency maps, GradCAM, and similar techniques provide about model behavior?
  • What are the strengths and limitations of gradient-based explainability methods?
  • How can probing classifiers help us understand what a model has learned?
  • What are the limitations of probing classifiers, and how can they be addressed?

Objectives

  • Explain how saliency maps and GradCAM work and their applications in understanding model predictions.
  • Introduce GradCAM as a method to visualize the important features used by a model.
  • Understand the concept of probing classifiers and how they assess the representations learned by models.

A Deep Dive into Methods for Understanding Model Behavior


In the previous section, we scratched the surface of explainability methods, introducing you to the broad classes of methods designed to understand different aspects of a model’s behavior.

Now, we will dive deeper into two widely used methods, each of which answers a key question:

What part of my input causes this prediction?


When a model makes a prediction, we often want to know which parts of the input were most important in generating that prediction. This helps confirm if the model is making its predictions for the right reasons. Sometimes, models use features totally unrelated to the task for their prediction - these are known as ‘spurious correlations’. For example, a model might predict that a picture contains a dog because it was taken in a park, and not because there is actually a dog in the picture.

Saliency maps are among the simplest and most popular methods used toward this end. We will work with a more sophisticated version of this method, known as GradCAM.

Method and Examples

A saliency map is a kind of visualization: a heatmap over the input that shows which parts of the input are most important in generating the model’s prediction. Saliency maps can be calculated using the gradients of a neural network, or by perturbing the input to any ML model and observing how the model reacts to those perturbations. The key intuition is that if a small change to a part of the input causes a large change in the model’s prediction, then that part of the input is important for the prediction. Gradients are useful here because they indicate how much the model’s prediction would change if the input were changed slightly.

For example, in an image classification task, a saliency map can be used to highlight the parts of the image that the model is focusing on to make its prediction. In a text classification task, a saliency map can be used to highlight the words or phrases that are most important for the model’s prediction.
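To make the gradient-based idea concrete, here is a minimal sketch of a vanilla gradient saliency map. This is an illustrative sketch, not code from this lesson: the `model` and `image` names are hypothetical placeholders for a differentiable PyTorch image classifier and a preprocessed input tensor.

PYTHON

import torch

def gradient_saliency(model, image, target_class):
    '''
    Compute a simple gradient-based saliency map.
    :param model: a differentiable image classifier (e.g., a torchvision model in eval mode).
    :param image: a preprocessed input tensor of shape (1, C, H, W).
    :param target_class: index of the class whose score we want to explain.
    :return: an (H, W) tensor; larger values mean the pixel mattered more to the prediction.
    '''
    image = image.clone().requires_grad_(True)  # track gradients with respect to the input pixels
    score = model(image)[0, target_class]       # score (logit) for the class of interest
    score.backward()                            # gradient of that score with respect to each pixel
    # Take the absolute gradient and collapse the color channels into a single heatmap
    return image.grad.abs().max(dim=1).values.squeeze(0)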

GradCAM is an extension of this idea, which uses the gradients of the final layer of a convolutional neural network to generate a heatmap that highlights the important regions of an image. This heatmap can be overlaid on the original image to visualize which parts of the image are most important for the model’s prediction.

Other variants of this method include Integrated Gradients, SmoothGrad, and others, which are designed to provide more robust and reliable explanations for model predictions. However, GradCAM is a good starting point for understanding how saliency maps work, and is a popularly used approach.

Alternative approaches, which may not directly generate heatmaps, include LIME and SHAP, which are also popular and recommended for further reading.

Limitations and Extensions

Gradient-based saliency methods like GradCAM are fast to compute, requiring only a handful of backpropagation passes through the model to generate the heatmap. They are also architecture-agnostic, in the sense that they can be applied to any differentiable model trained with gradient descent. Additionally, the results are intuitive and easy to understand, making them useful for explaining model predictions to non-experts.

However, these methods require white-box access to a differentiable model, and they are difficult to apply to tasks beyond classification, which limits their use with many recent generative models (think LLMs).

Another limitation is that the insights gained from these methods are not directly actionable: knowing which part of the input drove the prediction does not explain why it did. And when an issue does surface, it is hard to tell whether it reflects an underlying problem with the model or is specific to the inputs that were tested.

What part of my model causes this prediction?


When a model makes a correct prediction on a task it is applied to (known as a ‘downstream task’), probing classifiers can be used to check whether the model actually contains the information or knowledge required to make that prediction, or whether it just made a lucky guess. Furthermore, probes can identify the specific components of the model that contain this information, providing crucial insights for developing better models over time.

Method and Examples

A neural network takes its input as a series of vectors, or representations, and transforms them through a series of layers to produce an output. The job of the main body of the neural network is to develop representations that are as useful for the downstream task as possible, so that the final few layers of the network can make a good prediction.

This essentially means that a good-quality representation is one that already contains all the information required to make a good prediction. In other words, the features or representations from the model are easily separable by a simple classifier. That classifier is what we call a ‘probe’. A probe is a simple model that takes the model’s representations as input and tries to learn the downstream task from them. The probe itself is deliberately kept too simple to learn the task on its own. This means that the only way the probe can perform well is if the representations it is given already contain the information needed to make the prediction.

These representations can be taken from any part of the model. Generally, using representations from the last layer of a neural network helps identify whether the model contains the information needed for the downstream task at all. This can be extended further: probing the representations from different layers of the model can help identify where in the model the information is stored, and how it is transformed as it passes through the model.

Probes have been used frequently in NLP, where they are used to check whether language models contain certain kinds of linguistic information. These probes can be designed with varying levels of complexity. For example, simple probes have shown that language models contain information about basic syntactic features like part-of-speech tags, and more complex probes have shown that models encode entire parse trees of sentences.

Limitations and Extensions

One large challenge in using probes is identifying the correct architectural design of the probe. Too simple, and it may not be able to learn the downstream task at all. Too complex, and it may be able to learn the task even if the model does not contain the information required to make the prediction.

Another large limitation is that even if a probe is able to learn the downstream task, it does not mean that the model is actually using the information contained in the representations to make the prediction. So essentially, a probe can only tell us if a part of the model can make the prediction, not if it does make the prediction.

A new approach known as Causal Tracing addresses this limitation. Its objective is similar to that of probes: to understand which part of a model contains information relevant to a downstream task. The approach involves iterating through the parts of the model being examined (e.g., all of its layers) and disrupting the information flow through each one in turn (this could be as simple as adding noise to the weights of that component). If the model’s performance on the downstream task drops sharply when a specific component is disrupted, we can be confident that the component not only contains the information required to make the prediction, but that the model actually uses that information to make it.
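As a rough illustration of the causal-tracing loop described above (a hypothetical sketch, not code from a specific library), one could perturb each layer's weights in turn and measure the drop in downstream performance. Here, `evaluate_fn` is an assumed user-supplied function that scores the model on the downstream task.

PYTHON

import torch

def causal_trace(model, evaluate_fn, noise_std=0.1):
    '''
    Disrupt each weight-bearing module of a PyTorch model in turn and record
    how much the downstream performance drops.
    :param model: the model to examine.
    :param evaluate_fn: a function that takes the model and returns a performance score (e.g., accuracy).
    :param noise_std: standard deviation of the Gaussian noise used to disrupt the weights.
    :return: dict mapping module names to performance drops; large drops suggest the component is actually used.
    '''
    baseline = evaluate_fn(model)
    drops = {}
    for name, module in model.named_modules():
        if getattr(module, 'weight', None) is None:
            continue
        original = module.weight.data.clone()                         # save the weights
        module.weight.data += noise_std * torch.randn_like(original)  # disrupt the information flow
        drops[name] = baseline - evaluate_fn(model)                   # measure the performance drop
        module.weight.data = original                                 # restore the weights
    return drops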

Challenge

Now, it’s time to try implementing these methods yourself! Pick one of the two hands-on exercises that follow (a linear probe or GradCAM) to work on.

Time to get your hands dirty. Good luck, and have fun!

Content from Explainability methods: linear probe


Last updated on 2024-11-19 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • How can a probing classifier reveal what information is stored in a model’s hidden representations?
  • Which layers of a pre-trained language model carry the most information about a downstream task such as sentiment analysis?

Objectives

  • Extract embeddings from different layers of a pre-trained language model.
  • Train and evaluate a simple probe on those embeddings to locate where task-relevant information is stored in the model.

PYTHON

# Let's start by importing the necessary libraries.

import os
import torch
import logging
import numpy as np
from typing import Tuple
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from datasets import load_dataset, Dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig

logging.basicConfig(level=logging.INFO)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # This is needed to avoid a warning from huggingface

Now, let’s set the random seed to ensure reproducibility. Setting random seeds is like setting a starting point for your machine learning adventure. It ensures that every time you train your model, it starts from the same place, using the same random numbers, making your results consistent and comparable.

PYTHON

# Set random seeds for reproducibility - pick any number of your choice to set the seed. We use 42, since that is the answer to everything, after all.
torch.manual_seed(42)
Loading the Dataset

Let’s load our data: the IMDB Movie Review dataset. The dataset contains text reviews and their corresponding sentiment labels (positive or negative). The label 1 corresponds to a positive review, and 0 corresponds to a negative review.

PYTHON

def load_imdb_dataset(keep_samples: int = 100) -> Tuple[Dataset, Dataset, Dataset]:
    '''
    Load the IMDB dataset from huggingface.
    The dataset contains text reviews and their corresponding sentiment labels (positive or negative).
    The label 1 corresponds to a positive review, and 0 corresponds to a negative review.
    :param keep_samples: Number of samples to keep, for faster training.
    :return: train, dev, test datasets. Each can be treated as a dictionary with keys 'text' and 'label'.
    '''
    dataset = load_dataset('imdb')

    # Keep only a subset of the data for faster training
    train_dataset = Dataset.from_dict(dataset['train'].shuffle(seed=42)[:keep_samples])
    dev_dataset = Dataset.from_dict(dataset['test'].shuffle(seed=42)[:keep_samples])
    test_dataset = Dataset.from_dict(dataset['test'].shuffle(seed=42)[keep_samples:2*keep_samples])

    # train_dataset[0] will return {'text': ...., 'label': 0}
    logging.info(f'Loaded IMDB dataset: {len(train_dataset)} training samples, {len(dev_dataset)} dev samples, {len(test_dataset)} test samples.')
    return train_dataset, dev_dataset, test_dataset

PYTHON

train_dataset, dev_dataset, test_dataset = load_imdb_dataset(keep_samples=50)
Loading the Model

We will load a model from huggingface, and use this model to get the embeddings for the probe. We use distilBERT for this example, but feel free to explore other models from huggingface after the exercise.

BERT is a transformer-based model, and is known to perform well on a variety of NLP tasks. The model is pre-trained on a large corpus of text, and can be fine-tuned for specific tasks.

PYTHON

def load_model(model_name: str) -> Tuple[AutoModel, AutoTokenizer]:
    '''
    Load a model from huggingface.
    :param model_name: Check huggingface for acceptable model names.
    :return: Model and tokenizer.
    '''
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, config=config)
    model.config.max_position_embeddings = 128  # Reducing from default 512 to 128 for computational efficiency

    logging.info(f'Loaded model and tokenizer: {model_name} with {model.config.num_hidden_layers} layers, '
                 f'hidden size {model.config.hidden_size} and sequence length {model.config.max_position_embeddings}.')
    return model, tokenizer

PYTHON

# To play around with other models, find a list of models and their model_ids at: https://huggingface.co/models
model, tokenizer = load_model('distilbert-base-uncased') #'bert-base-uncased' has 12 layers and may take a while to process. We'll investigate distilbert instead.

Let’s see what the model’s architecture looks like. How many layers does it have?

PYTHON

print(model)

Let’s see if your answer matches the actual number of layers in the model.

PYTHON

num_layers = model.config.num_hidden_layers
print(f'The model has {num_layers} layers.')
Setting up the Probe

Before we define the probing classifier or probe, let’s set up some utility functions the probe will use. The probe will be trained on hidden representations from a specific layer of the model. The get_embeddings_from_model function retrieves these intermediate layer representations (also known as embeddings) from a user-defined layer number.

The visualize_embeddings method can be used to see what these high dimensional hidden embeddings would look like when converted into a 2D view. The visualization is not intended to be informative in itself, and is only an additional tool used to get a sense of what the inputs to the probing classifier may look like.

PYTHON

def get_embeddings_from_model(model: AutoModel, tokenizer: AutoTokenizer, layer_num: int, data: list[str]) -> torch.Tensor:
    '''
    Get the embeddings from a model.
    :param model: The model to use. This is needed to get the embeddings.
    :param tokenizer: The tokenizer to use. This is needed to convert the data to input IDs.
    :param layer_num: The layer to get embeddings from. 0 is the input embeddings, and the last layer is the output embeddings.
    :param data: The data to get embeddings for. A list of strings.
    :return: The embeddings. Shape is N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
    '''
    logging.info(f'Getting embeddings from layer {layer_num} for {len(data)} samples...')

    # Batch the data for computational efficiency
    batch_size = 32
    batch_num = 1
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        logging.info(f'Getting embeddings for batch {batch_num}...')
        batch_num += 1

        # Tokenize the batch of data, padding every batch to the same fixed length
        # so that embeddings from different batches can be concatenated below
        inputs = tokenizer(batch, return_tensors='pt', padding='max_length', truncation=True, max_length=128)

        # Get the embeddings from the model (no gradients are needed here)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Get the embeddings for the specified layer
        embeddings = outputs.hidden_states[layer_num]

        # Concatenate the embeddings from each batch
        if i == 0:
            all_embeddings = embeddings
        else:
            all_embeddings = torch.cat([all_embeddings, embeddings], dim=0)

    logging.info(f'Got embeddings for {len(data)} samples from layer {layer_num}. Shape: {all_embeddings.shape}')
    return all_embeddings

PYTHON

def visualize_embeddings(embeddings: torch.Tensor, labels: list, layer_num: int, save_plot: bool = False) -> None:
    '''
    Visualize the embeddings using t-SNE.
    :param embeddings: The embeddings to visualize. Shape is N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
    :param labels: The labels for the embeddings. A list of integers.
    :return: None
    '''

    # Since we are working with sentiment analysis, which is sentence based task, we can use sentence embeddings.
    # The sentence embeddings are simply the mean of the token embeddings of that sentence.
    sentence_embeddings = torch.mean(embeddings, dim=1)  # N, D

    # Convert to numpy
    sentence_embeddings = sentence_embeddings.detach().numpy()
    labels = np.array(labels)

    # Visualize the embeddings using t-SNE
    tsne = TSNE(n_components=2, random_state=0)
    embeddings_2d = tsne.fit_transform(sentence_embeddings)

    negative_points = embeddings_2d[labels == 0]
    positive_points = embeddings_2d[labels == 1]

    # Plot the embeddings. We want to colour the datapoints by label.
    fig, ax = plt.subplots()
    ax.scatter(negative_points[:, 0], negative_points[:, 1], label='Negative', color='red', marker='o', s=10, alpha=0.7)
    ax.scatter(positive_points[:, 0], positive_points[:, 1], label='Positive', color='blue', marker='o', s=10, alpha=0.7)
    plt.xlabel('t-SNE dimension 1')
    plt.ylabel('t-SNE dimension 2')
    plt.title(f't-SNE of Sentence Embeddings - Layer{layer_num}')
    plt.legend()

    # Save the plot if needed, then display it
    if save_plot:
        plt.savefig(f'tsne_layer_{layer_num}.png')
    plt.show()

    logging.info('Visualized embeddings using t-SNE.')

Now, it’s finally time to define our probe! We set this up as a class, where the probe itself is an object of this class. The class also contains methods used to train and evaluate the probe.

Read through this code block in a bit more detail - from this whole exercise, this part provides you with the most useful takeaways on ways to define and train neural networks!

PYTHON

class Probe():
    def __init__(self, hidden_dim: int = 768, class_size: int = 2) -> None:
        '''
        Initialize the probe.
        :param hidden_dim: The dimensionality of the embeddings fed to the probe (and of its hidden layer).
        :param class_size: The number of output classes.
        :return: None
        '''

        # The probe is a small classifier with one hidden layer and an output layer.
        # The input to the probe is the embeddings from the model, and the output is the predicted class.

        # Exercise: Try playing around with the hidden_dim and the number of layers to see how it affects the probe's performance.
        # But watch out: if a complex probe performs well on the task, we don't know if the performance
        # is because of the model embeddings, or the probe itself learning the task!

        self.probe = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, class_size),

            # Add more layers here if needed

            # Note: no final activation is applied here, because CrossEntropyLoss (used in train())
            # expects raw logits and applies softmax internally.
        )


    def train(self, data_embeddings: torch.Tensor, labels: torch.Tensor, num_epochs: int = 10,
              learning_rate: float = 0.001, batch_size: int = 32) -> None:
        '''
        Train the probe on the embeddings of data from the model.
        :param data_embeddings: A tensor of shape N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
        :param labels: A tensor of shape N, where N is the number of samples. Each element is the label for the corresponding sample.
        :param num_epochs: The number of epochs to train the probe for. An epoch is one pass through the entire dataset.
        :param learning_rate: How fast the probe learns. A hyperparameter.
        :param batch_size: Used to batch the data for computational efficiency. A hyperparameter.
        :return:
        '''

        # Setup the loss function (training objective) for the training process.
        # The cross-entropy loss is used for multi-class classification, and represents the negative log likelihood of the true class.
        criterion = torch.nn.CrossEntropyLoss()

        # Setup the optimization algorithm to update the probe's parameters during training.
        # The Adam optimizer is an extension to stochastic gradient descent, and is a popular choice.
        optimizer = torch.optim.Adam(self.probe.parameters(), lr=learning_rate)

        # Train the probe
        logging.info('Training the probe...')
        for epoch in range(num_epochs):  # Pass over the data num_epochs times

            for i in range(0, len(data_embeddings), batch_size):

                # Iterate through one batch of data at a time
                batch_embeddings = data_embeddings[i:i+batch_size].detach()
                batch_labels = labels[i:i+batch_size]

                # Convert to sentence embeddings, since we are performing a sentence classification task
                batch_embeddings = torch.mean(batch_embeddings, dim=1)  # N, D

                # Get the probe's predictions, given the embeddings from the model
                outputs = self.probe(batch_embeddings)

                # Calculate the loss of the predictions, against the true labels
                loss = criterion(outputs, batch_labels)

                # Backward pass - update the probe's parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        logging.info('Trained the probe.')


    def predict(self, data_embeddings: torch.Tensor, batch_size: int = 32) -> torch.Tensor:
        '''
        Get the probe's predictions on the embeddings from the model, for unseen data.
        :param data_embeddings: A tensor of shape N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
        :param batch_size: Used to batch the data for computational efficiency.
        :return: A tensor of shape N, where N is the number of samples. Each element is the predicted class for the corresponding sample.
        '''

        # Iterate through batches
        for i in range(0, len(data_embeddings), batch_size):

            # Iterate through one batch of data at a time
            batch_embeddings = data_embeddings[i:i+batch_size]

            # Convert to sentence embeddings, since we are performing a sentence classification task
            batch_embeddings = torch.mean(batch_embeddings, dim=1)  # N, D

            # Get the probe's predictions (no gradients are needed at prediction time)
            with torch.no_grad():
                outputs = self.probe(batch_embeddings)

            # Get the predicted class for each sample
            _, predicted = torch.max(outputs, dim=-1)

            # Concatenate the predictions from each batch
            if i == 0:
                all_predicted = predicted
            else:
                all_predicted = torch.cat([all_predicted, predicted], dim=0)

        return all_predicted


    def evaluate(self, data_embeddings: torch.tensor, labels: torch.tensor, batch_size: int = 32) -> float:
        '''
        Evaluate the probe's performance by testing it on unseen data.
        :param data_embeddings: A tensor of shape N, L, D, where N is the number of samples, L is the length of the sequence, and D is the dimensionality of the embeddings.
        :param labels: A tensor of shape N, where N is the number of samples. Each element is the label for the corresponding sample.
        :return: The accuracy of the probe on the unseen data.
        '''

        # Iterate through batches
        for i in range(0, len(data_embeddings), batch_size):

            # Iterate through one batch of data at a time
            batch_embeddings = data_embeddings[i:i+batch_size]
            batch_labels = labels[i:i+batch_size]

            # Convert to sentence embeddings, since we are performing a sentence classification task
            batch_embeddings = torch.mean(batch_embeddings, dim=1)  # N, D

            # Get the probe's predictions
            with torch.no_grad():
                outputs = self.probe(batch_embeddings)

            # Get the predicted class for each sample
            _, predicted = torch.max(outputs, dim=-1)

            # Concatenate the predictions from each batch
            if i == 0:
                all_predicted = predicted
                all_labels = batch_labels
            else:
                all_predicted = torch.cat([all_predicted, predicted], dim=0)
                all_labels = torch.cat([all_labels, batch_labels], dim=0)

        # Calculate the accuracy of the probe
        correct = (all_predicted == all_labels).sum().item()
        accuracy = correct / all_labels.shape[0]
        logging.info(f'Probe accuracy: {accuracy:.2f}')

        return accuracy

PYTHON

# Initialize the probing classifier (or probe)
probe = Probe()
Analysing the model using Probes

Time to start evaluating the model using our probing tool! Let’s see which layer contains the most information about sentiment analysis on IMDB. To find out, we will train a probe on embeddings from each layer of the model and see which one performs best on the dev set.

PYTHON

layer_wise_accuracies = []
best_probe, best_layer, best_accuracy = None, -1, 0

for layer_num in range(num_layers):
    logging.info(f'\n\nEvaluating representations of layer {layer_num}...')

    train_embeddings = get_embeddings_from_model(model, tokenizer, layer_num=layer_num, data=train_dataset['text'])
    dev_embeddings = get_embeddings_from_model(model, tokenizer, layer_num=layer_num, data=dev_dataset['text'])
    train_labels, dev_labels = torch.tensor(train_dataset['label'],  dtype=torch.long), torch.tensor(dev_dataset['label'],  dtype=torch.long)

    # Before training the probe, let's visualize the embeddings using t-SNE.
    # If the layer has information about sentiment analysis, would we see some structure in the embeddings?
    # Compare plots from layers where the probe does poorly, with ones where it does well. What do you notice?
    visualize_embeddings(embeddings=train_embeddings, labels=train_dataset['label'], layer_num=layer_num, save_plot=False)

    # Now, let's train the probe on the embeddings from the model.
    # Feel free to play around with the training hyperparameters, and see what works best for your probe.
    probe = Probe()
    probe.train(data_embeddings=train_embeddings, labels=train_labels,
                num_epochs=5, learning_rate=0.001, batch_size=32)

    # Let's see how well our probe does on a held out dev set
    accuracy = probe.evaluate(data_embeddings=dev_embeddings, labels=dev_labels)
    layer_wise_accuracies.append(accuracy)

    # Keep track of the best probe
    if accuracy > best_accuracy:
        best_probe, best_layer, best_accuracy = probe, layer_num, accuracy

PYTHON

# Seeing a list of accuracies can be hard to interpret. Let's plot the layer-wise accuracies to see which layer is best.
plt.plot(layer_wise_accuracies)
plt.xlabel('Layer')
plt.ylabel('Accuracy')
plt.title('Probe Accuracy by Layer')
plt.grid(alpha=0.3)
plt.show()

Which layer has the best accuracy? What does this tell us about the model?

Let’s go ahead and stress test this. Is the best layer able to predict sentiment for sentences outside the IMDB dataset?

For answering this question, you are the test set! Try to think of challenging sequences for which the model may not be able to predict sentiment.

PYTHON

test_sequences = ['Your sentence here', 'Here is another sentence']
embeddings = get_embeddings_from_model(model=model, tokenizer=tokenizer, layer_num=best_layer, data=test_sequences)
preds = best_probe.predict(data_embeddings=embeddings)
predictions = ['Positive' if pred == 1 else 'Negative' for pred in preds]
print(f'Predictions for test sequences: {predictions}')

Content from Explainability methods: GradCAM


Last updated on 2024-07-03 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • How can GradCAM be used to visualize which regions of an image drive a model’s prediction?

Objectives

  • Apply GradCAM to a pre-trained ResNet-50 model and interpret the resulting heatmaps for different target classes.

PYTHON

# Let's begin by installing the grad-cam package - this will significantly simplify our implementation
!pip install grad-cam

PYTHON

# Packages to download test images
import requests

# Packages to view and process images
import cv2
import numpy as np
from PIL import Image
from google.colab.patches import cv2_imshow

# Packages to load the model
import torch
from torchvision.models import resnet50

# GradCAM packages
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

PYTHON

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Load Model

We’ll load the ResNet-50 model from torchvision. This model is pre-trained on the ImageNet dataset, which contains 1.2 million images across 1000 classes. ResNet-50 is a popular convolutional neural network. You can learn more about it here: https://pytorch.org/hub/pytorch_vision_resnet/

PYTHON

model = resnet50(pretrained=True).to(device).eval()
Load Test Image

PYTHON

# Let's first take a look at the image, which we source from the GradCAM package

url = "https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/both.png"
Image.open(requests.get(url, stream=True).raw)

PYTHON

# Cute, isn't it? Do you prefer dogs or cats?

# We will need to convert the image into a tensor to feed it into the model.
# Let's create a function to do this for us.
def load_image(url):
    rgb_img = np.array(Image.open(requests.get(url, stream=True).raw))
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img).to(device)
    return input_tensor, rgb_img

PYTHON

input_tensor, rgb_image = load_image(url)

Grad-CAM Time!

PYTHON

# Let's start by selecting which layers of the model we want to use to generate the CAM.
# For that, we will need to inspect the model architecture.
# We can do that by simply printing the model object.
print(model)

Here we want to interpret what the model as a whole is doing (not what a specific layer is doing). That means that we want to use the embeddings of the last layer before the final classification layer. This is the layer that contains the information about the image encoded by the model as a whole.

Looking at the model, we can see that the last layer before the final classification layer is layer4.

PYTHON

target_layers = [model.layer4]

We also want to pick a label for the CAM - this is the class we want to visualize the activation for. Essentially, we want to see what the model is looking at when it is predicting a certain class.

Since ResNet was trained on the ImageNet dataset with 1000 classes, let’s get an indexed list of those classes. We can then pick the index of the class we want to visualize.

PYTHON

imagenet_categories_url = \
     "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
labels = eval(requests.get(imagenet_categories_url).text)
labels

Well, that’s a lot! To simplify things, we have already picked out the indices of a few interesting classes.

  • 157: Siberian Husky
  • 162: Beagle
  • 245: French Bulldog
  • 281: Tabby Cat
  • 285: Egyptian cat
  • 360: Otter
  • 537: Dog Sleigh
  • 799: Sliding Door
  • 918: Street Sign

PYTHON

# Specify the target class for visualization here. If you set this to None, the class with the highest score from the model will automatically be used.
visualized_class_id = 245

PYTHON

def viz_gradcam(model, target_layers, class_id):
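  # Note: this function uses the global input_tensor and rgb_image defined earlier in this notebook.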

  if class_id is None:
    targets = None
  else:
    targets = [ClassifierOutputTarget(class_id)]

  cam_algorithm = GradCAM
  with cam_algorithm(model=model, target_layers=target_layers) as cam:
      grayscale_cam = cam(input_tensor=input_tensor,
                          targets=targets)

      grayscale_cam = grayscale_cam[0, :]

      cam_image = show_cam_on_image(rgb_image, grayscale_cam, use_rgb=True)
      cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

  cv2_imshow(cam_image)

Finally, we can start visualizing! Let’s begin by seeing what parts of the image the model looks at to make its most confident prediction.

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=None)

Interesting, it looks like the model totally ignores the cat and makes a prediction based on the dog. If we set the output class to “French Bulldog” (class_id=245), we see the same visualization - meaning that the model is indeed looking at the correct part of the image to make the correct prediction.

Let’s see what the heatmap looks like when we force the model to look at the cat.

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=281)

The model is indeed looking at the cat when asked to predict the class “Tabby Cat” (class_id=281)! But why is its top prediction still the dog? GradCAM only shows where the evidence for a chosen class lies, not which class ultimately wins: in this image, the evidence for the bulldog class is simply stronger than the evidence for “Tabby Cat”, so the dog class comes out on top.

Let’s see another example of this. The image contains not only a dog and a cat, but also other items in the background. Can the model correctly identify the sliding door?

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=799)

It can! However, it seems to also think of the shelf behind the dog as a door.

Let’s try an unrelated object now. Where in the image does the model see a street sign?

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=918)

Looks like our analysis has revealed a shortcoming of the model! It seems to perceive cats and street signs similarly.

Ideally, when the target class is some unrelated object, a good model will look at no significant part of the image. For example, the model does a good job with the class for Dog Sleigh.

PYTHON

viz_gradcam(model=model, target_layers=target_layers, class_id=537)

Explaining model predictions through visualization techniques like this can be subjective and prone to error. However, it still provides a degree of insight that a completely black-box model would not.

Spend some time playing around with different classes and seeing which part of the image the model looks at. Feel free to play around with other base images as well. Have fun!

Content from Estimating model uncertainty


Last updated on 2024-11-15 | Edit this page

Estimated time: 15 minutes

Overview

Questions

  • What is model uncertainty, and how can it be categorized?
  • How do uncertainty estimation methods intersect with OOD detection methods?
  • What are the computational challenges of estimating model uncertainty?
  • When is uncertainty estimation useful, and what are its limitations?
  • Why is OOD detection often preferred over traditional uncertainty estimation techniques in modern applications?

Objectives

  • Define and distinguish between aleatoric and epistemic uncertainty in machine learning models.
  • Explore common techniques for estimating aleatoric and epistemic uncertainty.
  • Understand why OOD detection has become a widely adopted approach in many real-world applications.
  • Compare and contrast the goals and computational costs of uncertainty estimation and OOD detection.
  • Summarize when and where different uncertainty estimation methods are most useful.

Estimating model uncertainty

Understanding how confident a model is in its predictions is a valuable tool for building trustworthy AI systems, especially in high-stakes settings like healthcare or autonomous vehicles. Model uncertainty estimation focuses on quantifying the model’s confidence and is often used to identify predictions that require further review or caution.

Model uncertainty can be divided into two categories:

  • Aleatoric uncertainty: Inherent noise in the data (e.g., overlapping classes) that cannot be reduced, even with more data.
  • Epistemic uncertainty: Gaps in the model’s knowledge about the data distribution, which can be reduced by using more data or improved models.

Common techniques for estimating aleatoric uncertainty

Aleatoric uncertainty arises from the data itself. Methods to estimate it include:

  • Predictive variance in regression models: Outputs the variance of the predicted value, reflecting the noise in the data. For instance, in a regression task predicting house prices, predictive variance highlights how much randomness exists in the relationship between input features (like square footage) and price.
  • Heteroscedastic models: Use specialized loss functions that allow the model to predict the noise level in the data directly. These models are particularly critical in fields like robotics, where sensor noise varies significantly depending on environmental conditions. For example, a robot navigating in bright daylight versus dim lighting conditions may experience vastly different levels of noise in its sensor inputs, and heteroscedastic models can help account for this variability. A minimal code sketch of this idea appears after this list.
  • Data augmentation and perturbation analysis: Assess variability in predictions by adding noise to the input data and observing how much the model’s outputs change. A highly sensitive change in predictions may indicate underlying noise or instability in the data. For instance, in image classification, augmenting training data with synthetic noise can help the model better handle real-world imperfections like motion blur or occlusions.
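To make the heteroscedastic idea concrete, here is a minimal sketch (not from any specific library or paper) of a regression network that predicts both a mean and a variance for each input and is trained with a Gaussian negative log-likelihood loss. All names and the toy data are placeholders.

PYTHON

import torch
import torch.nn as nn

class HeteroscedasticRegressor(nn.Module):
    '''Predicts a mean and a (positive) variance for every input.'''
    def __init__(self, in_dim, hidden_dim=64):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.mean_head = nn.Linear(hidden_dim, 1)
        self.log_var_head = nn.Linear(hidden_dim, 1)  # predicting log-variance keeps the variance positive

    def forward(self, x):
        h = self.backbone(x)
        return self.mean_head(h), self.log_var_head(h).exp()

# The Gaussian negative log-likelihood rewards the model for predicting larger
# variance on noisier inputs instead of assuming one global noise level.
model = HeteroscedasticRegressor(in_dim=8)
criterion = nn.GaussianNLLLoss()
x, y = torch.randn(32, 8), torch.randn(32, 1)  # toy stand-in data
mean, var = model(x)
loss = criterion(mean, y, var)
loss.backward()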

Common techniques for estimating epistemic uncertainty

Epistemic uncertainty arises from the model’s lack of knowledge about certain regions of the data space. Techniques to estimate it include:

  • Monte Carlo dropout: In this method, dropout (a regularization technique that randomly disables some neurons) is applied during inference, and multiple forward passes are performed for the same input. The variability in the outputs across these passes gives an estimate of uncertainty. Intuitively, each forward pass simulates a slightly different version of the model, akin to an ensemble. If the model consistently predicts similar outputs despite dropout, it is confident; if predictions vary widely, the model is uncertain about that input. A minimal code sketch of this procedure appears after this list.
  • Bayesian neural networks: These networks incorporate probabilistic layers to model uncertainty directly in the weights of the network. Instead of assigning a single deterministic weight to each connection, Bayesian neural networks assign distributions to these weights, reflecting the uncertainty about their true values. During inference, these distributions are sampled multiple times to generate predictions, which naturally include uncertainty estimates. While Bayesian neural networks are theoretically rigorous and align well with the goal of epistemic uncertainty estimation, they are computationally expensive and challenging to scale for large datasets or deep architectures. This is because calculating or approximating posterior distributions over all parameters becomes intractable as model size grows. To address this, methods like variational inference or Monte Carlo sampling are often used, but these approximations can introduce inaccuracies, making Bayesian approaches less practical for many modern applications. Despite these challenges, Bayesian neural networks remain valuable for research contexts where precise uncertainty quantification is needed or in domains where computational resources are less of a concern.
  • Ensemble models: These involve training multiple models on the same data, each starting with different initializations or random seeds. The ensemble’s predictions are aggregated, and the variance in their outputs reflects uncertainty. This approach works well because different models often capture different aspects of the data. For example, if all models agree, the prediction is confident; if they disagree, there is uncertainty. Ensembles are effective but computationally expensive, as they require training and evaluating multiple models.
  • Out-of-distribution detection: Identifies inputs that fall significantly outside the training distribution, flagging areas where the model’s predictions are unreliable. Many OOD methods produce continuous scores, such as Mahalanobis distance or energy-based scores, which measure how novel or dissimilar an input is from the training data. These scores can be interpreted as a form of epistemic uncertainty, providing insight into how unfamiliar an input is. However, OOD detection focuses on distinguishing ID from OOD inputs rather than offering confidence estimates for predictions on ID inputs.
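As a minimal illustration of Monte Carlo dropout, here is a sketch assuming a PyTorch classifier that contains dropout layers; `model` and `x` are hypothetical placeholders.

PYTHON

import torch

def enable_dropout(model):
    # Keep dropout layers stochastic while the rest of the model stays in eval mode
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.train()

def mc_dropout_predict(model, x, n_passes=20):
    '''Average predictions over several stochastic forward passes; the spread across passes is an uncertainty estimate.'''
    model.eval()
    enable_dropout(model)
    with torch.no_grad():
        probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(n_passes)])
    return probs.mean(dim=0), probs.std(dim=0)  # a large standard deviation signals high epistemic uncertainty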

Method selection summary table

Understanding size categories in table

To help guide method selection, here are rough definitions for model size, data size, and compute requirements used in the table:

Model size

  • Small: Fewer than 10M parameters (e.g., logistic regression, LeNet).
  • Medium: 10M–100M parameters (e.g., ResNet-50, BERT-base).
  • Large: More than 100M parameters (e.g., GPT-3, Vision Transformers).

Data size

  • Small: Fewer than 10,000 samples (e.g., materials science datasets).
  • Medium: 10,000–1M samples (e.g., ImageNet).
  • Large: More than 1M samples (e.g., Common Crawl, LAION-5B).

Compute time (approximate)

  • Low: Suitable for standard CPU or single GPU, training/inference in minutes to an hour.
  • Medium: Requires a modern GPU, training/inference in hours to a day.
  • High: Requires multiple GPUs/TPUs or distributed setups, training/inference in days to weeks.
| Method | Type of uncertainty | Key strengths | Key limitations | Common use cases/domains | Model size restrictions | Data size restrictions | Compute time (approx.) |
|---|---|---|---|---|---|---|---|
| Predictive variance | Aleatoric | Simple, intuitive for regression tasks | Limited to regression problems; doesn’t address epistemic uncertainty | Regression tasks like house price prediction, energy consumption modeling | Small to medium | Small to medium | Very low (single pass) |
| Heteroscedastic models | Aleatoric | Models variable noise across inputs | Requires specialized architectures or loss functions | Robotics (handling sensor noise), object detection in noisy environments | Medium to large (depends on task) | Medium | Medium (depends on task complexity) |
| Monte Carlo dropout | Epistemic | Easy to implement in existing neural networks | Computationally expensive due to multiple forward passes | Medical diagnosis models, small-scale deep learning tasks requiring uncertainty estimates | Medium | Medium to large | High (scales with forward passes) |
| Bayesian neural nets | Epistemic | Rigorous probabilistic foundation | Computationally prohibitive for large models/datasets | Research-intensive applications like materials science, drug discovery | Small to medium (challenging to scale) | Small to medium | Very high (depends on sampling) |
| Ensemble models | Epistemic | Effective and robust; captures diverse uncertainties | Resource-intensive; requires training multiple models | Financial risk assessment, autonomous systems, materials property prediction | Medium to large | Medium to large | Very high (training multiple models) |
| OOD detection | Epistemic | Efficient, scalable, excels at rejecting anomalous inputs | Limited to identifying OOD inputs, not fine-grained uncertainty | NLP (e.g., chatbot queries), computer vision (e.g., detecting novel objects in autonomous driving) | Medium to large | Small to large | Low to medium (scales efficiently) |

Why is OOD detection widely adopted?

Among epistemic uncertainty methods, OOD detection has become a widely adopted approach in real-world applications due to its ability to efficiently identify inputs that fall outside the training data distribution, where predictions are inherently unreliable. Many OOD detection techniques produce continuous scores that quantify the novelty or dissimilarity of inputs, which can be interpreted as a form of uncertainty. This makes OOD detection not only effective at rejecting anomalous inputs but also useful for prioritizing inputs based on their predicted risk.

For example, in autonomous vehicles, OOD detection can help flag unexpected scenarios (e.g., unusual objects on the road) in near real-time, enabling safer decision-making. Similarly, in NLP, OOD methods are used to identify queries or statements that deviate from a model’s training corpus, such as out-of-context questions in a chatbot system. In the next couple of episodes, we’ll see how to implement various OOD strategies.

Summary

While uncertainty estimation provides a broad framework for understanding model confidence, different methods are suited for specific types of uncertainty and use cases. OOD detection stands out as the most practical approach for handling epistemic uncertainty in modern applications, thanks to its efficiency and ability to reject anomalous inputs. Together, these methods form a complementary toolkit for building trustworthy AI systems.

Content from OOD detection: overview


Last updated on 2024-11-15 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • What are out-of-distribution (OOD) data, and why is detecting them important in machine learning models?
  • What are threshold-based methods, and how do they help detect OOD data?

Objectives

  • Understand the concept of out-of-distribution data and its implications for machine learning models.
  • Learn the principles behind threshold-based OOD detection methods.

Introduction to Out-of-Distribution (OOD) Data

What is OOD data?

Out-of-distribution (OOD) data refers to data that significantly differs from the training data on which a machine learning model was built, i.e., the in-distribution (ID). For example, the image below compares the training data distribution of CIFAR-10, a popular dataset used for image classification, with the vastly broader and more diverse distribution of images found on the internet:

OpenAI: CIFAR-10 training distribution vs. internet

CIFAR-10 contains 60,000 images across 10 distinct classes (e.g., airplanes, dogs, trucks), with carefully curated examples for each class. However, the internet features an essentially infinite variety of images, many of which fall outside these predefined classes or include unseen variations (e.g., new breeds of dogs or novel vehicle designs). This contrast highlights the challenges models face when they encounter data that significantly differs from their training distribution.

How OOD data manifests in ML pipelines

The difference between in-distribution (ID) and OOD data can arise from:

  • Semantic shift: The OOD sample belongs to a class that was not present during training.
  • Covariate shift: The OOD sample comes from a domain where the input feature distribution is drastically different from the training data.

Why does OOD data matter?

Models trained on a specific distribution might make incorrect predictions on OOD data, leading to unreliable outputs. In critical applications (e.g., healthcare, autonomous driving), encountering OOD data without proper handling can have severe consequences.

Ex1: Tesla crashes into jet

In April 2022, a Tesla Model Y crashed into a $3.5 million private jet at an aviation trade show in Spokane, Washington, while operating on the “Smart Summon” feature. The feature allows Tesla vehicles to autonomously navigate parking lots to their owners, but in this case, it resulted in a significant mishap.

  • The Tesla was summoned by its owner using the Tesla app, which requires holding down a button to keep the car moving. The car continued to move forward even after making contact with the jet, pushing the expensive aircraft and causing notable damage.
  • The crash highlighted several issues with Tesla’s Smart Summon feature, particularly its object detection capabilities. The system failed to recognize and appropriately react to the presence of the jet, a problem that has been observed in other scenarios where the car’s sensors struggle with objects that are lifted off the ground or have unusual shapes.

Ex2: IBM Watson for Oncology

Around a decade ago, the excitement surrounding AI in healthcare often exceeded its actual capabilities. In 2016, IBM launched Watson for Oncology, an AI-powered platform for treatment recommendations, to much public enthusiasm. However, it soon became apparent that the system was both costly and unreliable, frequently generating flawed advice while operating as an opaque “black box.”

IBM Watson for Oncology faced several issues due to OOD data. The system was primarily trained on data from Memorial Sloan Kettering Cancer Center (MSK), which did not generalize well to other healthcare settings. This led to the following problems:

  1. Unsafe recommendations: Watson for Oncology provided treatment recommendations that were not safe or aligned with standard care guidelines in many cases outside of MSK. This happened because the training data was not representative of the diverse medical practices and patient populations in different regions.
  2. Bias in training data: The system’s recommendations were biased towards the practices at MSK, failing to account for different treatment protocols and patient needs elsewhere. This bias is a classic example of an OOD issue, where the model encounters data (patients and treatments) during deployment that significantly differ from its training data.

By 2022, IBM had taken Watson for Oncology offline, marking the end of its commercial use.

Ex3: Doctors using GPT3

Misdiagnosis and inaccurate medical advice

In various studies and real-world applications, GPT-3 has been shown to generate inaccurate medical advice when faced with OOD data. This can be attributed to the fact that the training data, while extensive, does not cover all possible medical scenarios and nuances, leading to hallucinations or incorrect responses when encountering unfamiliar input.

A study published by researchers at Stanford found that GPT-3, even when using retrieval-augmented generation, provided unsupported medical advice in about 30% of its statements. For example, it suggested the use of a specific dosage for a defibrillator based on monophasic technology, while the cited source only discussed biphasic technology, which operates differently.

Fake medical literature references

Another critical OOD issue is the generation of fake or non-existent medical references by LLMs. When LLMs are prompted to provide citations for their responses, they sometimes generate references that sound plausible but do not actually exist. This can be particularly problematic in academic and medical contexts where accurate sourcing is crucial.

In evaluations of GPT-3’s ability to generate medical literature references, it was found that a significant portion of the references were either entirely fabricated or did not support the claims being made. This was especially true for complex medical inquiries that the model had not seen in its training data.

Recognizing OOD data in your work

Think of a scenario from your field of work or study where encountering out-of-distribution (OOD) data would be problematic. Consider the following:

  • What would be the in-distribution (ID) data in that context?
  • What might constitute OOD data, and how could it impact the results or outputs of your system/model?

Share your example with the group. Discuss any strategies currently used or that could be used to mitigate the challenges posed by OOD data in your example.

Detecting and handling OOD data

Given the problems posed by OOD data, a reliable model should identify such instances, and then:

  1. Reject them during inference
  2. Ideally, hand these OOD instances off to a model trained on a distribution that better matches them (i.e., one for which they are in-distribution).

The second step is much more involved, since it requires matching OOD data to essentially an infinite number of possible classes. For the current scope of this workshop, we will focus on just the first step.

Threshold-based

How can we ensure our models do not perform poorly in the presence of OOD data? Over the past several years, there have been a wide assortment of new methods developed to tackle this task. The central idea behind all of these methods is to define a threshold on a certain score or confidence measure, beyond which the data point is considered out-of-distribution. Typically, these scores are derived from the model’s output probabilities or other statistical measures of uncertainty. There are two general classes of threshold-based OOD detection methods: output-based and distance-based.

1) Output-based thresholds

Output-based Out-of-Distribution (OOD) detection refers to methods that determine whether a given input is out-of-distribution based on the output of a trained model. These methods typically analyze the model’s confidence scores, energy scores, or other output metrics to identify data points that are unlikely to belong to the distribution the model was trained on. The main approaches within output-based OOD detection include the following (a short code sketch of both scores appears after this list):

  • Softmax scores: The softmax output of a neural network represents the predicted probabilities for each class. A common threshold-based method involves setting a confidence threshold, and if the maximum softmax score of an instance falls below this threshold, it is flagged as OOD.
  • Energy: The energy-based method also uses the network’s output but measures the uncertainty in a more nuanced way by calculating an energy score. The energy score typically captures the confidence more robustly, especially in high-dimensional spaces, and can be considered a more general and reliable approach than just using softmax probabilities.
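A rough sketch of both scores is shown below. Here `logits` is a stand-in for a model's raw outputs, and the threshold values are placeholders that would be tuned on validation data.

PYTHON

import torch

logits = torch.randn(4, 10)  # stand-in for raw model outputs: 4 inputs, 10 classes

# Maximum softmax probability: a low maximum probability suggests the input may be OOD
msp = torch.softmax(logits, dim=-1).max(dim=-1).values

# Energy score: the negative log-sum-exp of the logits; higher energy suggests OOD
energy = -torch.logsumexp(logits, dim=-1)

is_ood_softmax = msp < 0.7     # placeholder threshold
is_ood_energy = energy > -5.0  # placeholder threshold
print(is_ood_softmax, is_ood_energy)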

2) Distance-based thresholds

Distance-based methods calculate the distance of an instance from the distribution of training data features learned by the model. If the distance is beyond a certain threshold, the instance is considered OOD. Common distance-based approaches include the following (a minimal sketch of the Mahalanobis approach appears after this list):

  • Mahalanobis distance: This method calculates the Mahalanobis distance of a data point from the mean of the training data distribution. A high Mahalanobis distance indicates that the instance is likely OOD.
  • K-nearest neighbors (KNN): This method involves computing the distance to the k-nearest neighbors in the training data. If the average distance to these neighbors is high, the instance is considered OOD.
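Here is a minimal sketch of the Mahalanobis approach using NumPy. The `train_features` and `test_features` arrays stand in for feature vectors extracted from a trained model, and the threshold is a placeholder to tune on held-out data.

PYTHON

import numpy as np

train_features = np.random.randn(1000, 64)  # stand-in for in-distribution training features
test_features = np.random.randn(5, 64)      # stand-in for features of new inputs

# Fit a Gaussian to the training features
mean = train_features.mean(axis=0)
cov_inv = np.linalg.pinv(np.cov(train_features, rowvar=False))  # pseudo-inverse for numerical stability

# Mahalanobis distance of each new input from the training distribution
diff = test_features - mean
distances = np.sqrt(np.einsum('ij,jk,ik->i', diff, cov_inv, diff))

is_ood = distances > 12.0  # placeholder threshold
print(distances, is_ood)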

We will focus on output-based methods (softmax and energy) in the next episode and then do a deep dive into distance-based methods in a later episode.

Key Points

  • Out-of-distribution (OOD) data significantly differs from training data and can lead to unreliable model predictions.
  • Threshold-based methods use model outputs or distances in feature space to detect OOD instances by defining a score threshold.

Content from OOD detection: softmax


Last updated on 2024-11-15 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • What is softmax-based out-of-distribution (OOD) detection, and how does it work?
  • What are the strengths and limitations of using softmax scores for OOD detection?
  • How do threshold choices affect the performance of softmax-based OOD detection?
  • How can we assess and improve softmax-based OOD detection through evaluation metrics and visualization?

Objectives

  • Understand how softmax scores can be leveraged for OOD detection.
  • Explore the advantages and drawbacks of using softmax-based methods for OOD detection.
  • Learn how to visualize and interpret softmax-based OOD detection performance using tools like PCA and probability density plots.
  • Investigate the impact of thresholds on the trade-offs between detecting OOD and retaining in-distribution data.
  • Build a foundation for understanding more advanced output-based OOD detection methods, such as energy-based detection.

Example 1: Softmax scores

Softmax-based methods are among the most widely used techniques for out-of-distribution (OOD) detection, leveraging the probabilistic outputs of a model to differentiate between in-distribution (ID) and OOD data. These methods are inherently tied to models employing a softmax activation function in their final layer, such as logistic regression or neural networks with a classification output layer. Softmax produces a normalized probability distribution across the classes, which can then be thresholded to identify OOD instances.

In this lesson, we will train a logistic regression model to classify images from the Fashion MNIST dataset and explore how its softmax outputs can signal whether a given input belongs to the ID classes (e.g., T-shirts or pants) or is OOD (e.g., sandals). While softmax is most naturally applied in models with a logistic activation, alternative approaches, such as applying softmax-like operations post hoc to models with different architectures, are occasionally used. However, these alternatives are less common and may require additional considerations, such as recalibrating the output scores.

By focusing on logistic regression, we aim to illustrate the fundamental principles of softmax-based OOD detection in a simple and interpretable context before extending these ideas to more complex architectures.

PYTHON

# Settings used throughout this lesson
verbose = False
alpha = 0.2
max_iter = 10  # increase for better results after initial testing
n_epochs = 10  # increase for better results after initial testing

Prepare the ID (train and test) and OOD data

  • ID = T-shirts/tops (class 0) and Trousers/pants (class 1)
  • OOD = any other class. For illustrative purposes, we’ll focus on images of sandals as the OOD class.

PYTHON

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from keras.datasets import fashion_mnist

def prep_ID_OOD_datasests(ID_class_labels, OOD_class_labels):
    # Load Fashion MNIST dataset
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    
    # Prepare OOD data: Sandals = 5
    ood_filter = np.isin(test_labels, OOD_class_labels)
    ood_data = test_images[ood_filter]
    ood_labels = test_labels[ood_filter]
    print(f'ood_data.shape={ood_data.shape}')
    
    # Filter data for T-shirts (0) and Trousers (1) as in-distribution
    train_filter = np.isin(train_labels, ID_class_labels)
    test_filter = np.isin(test_labels, ID_class_labels)
    
    train_data = train_images[train_filter]
    train_labels = train_labels[train_filter]
    print(f'train_data.shape={train_data.shape}')
    
    test_data = test_images[test_filter]
    test_labels = test_labels[test_filter]
    print(f'test_data.shape={test_data.shape}')

    return train_data, test_data, ood_data, train_labels, test_labels, ood_labels


def plot_data_sample(train_data, ood_data):
    """
    Plots a sample of in-distribution and OOD data.

    Parameters:
    - train_data: np.array, array of in-distribution data images
    - ood_data: np.array, array of out-of-distribution data images

    Returns:
    - fig: matplotlib.figure.Figure, the figure object containing the plots
    """
    fig = plt.figure(figsize=(10, 4))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(train_data[i], cmap='gray')
        plt.title("In-Dist")
        plt.axis('off')
    for i in range(5):
        plt.subplot(2, 5, i + 6)
        plt.imshow(ood_data[i], cmap='gray')
        plt.title("OOD")
        plt.axis('off')
    
    return fig

PYTHON

train_data, test_data, ood_data, train_labels, test_labels, ood_labels = prep_ID_OOD_datasests([0,1], [5])
fig = plot_data_sample(train_data, ood_data)
fig.savefig('../images/OOD-detection_image-data-preview.png', dpi=300, bbox_inches='tight')
plt.show()
Preview of image dataset

Visualizing OOD and ID data

PCA

PCA visualization can provide insights into how well a model is separating ID and OOD data. If the OOD data overlaps significantly with ID data in the PCA space, it might indicate that the model could struggle to correctly identify OOD samples.

Focus on Linear Relationships: PCA is a linear dimensionality reduction technique. It assumes that the directions of maximum variance in the data can be captured by linear combinations of the original features. This can be a limitation when the data has complex, non-linear relationships, as PCA may not capture the true structure of the data. However, if you’re using a linear model (as we are here), PCA can be more appropriate for visualizing in-distribution (ID) and out-of-distribution (OOD) data because both PCA and linear models operate under linear assumptions. PCA will effectively capture the main variance in the data as seen by the linear model, making it easier to understand the decision boundaries and how OOD data deviates from the ID data within those boundaries.

PYTHON

# Flatten images for PCA and logistic regression
train_data_flat = train_data.reshape((train_data.shape[0], -1))
test_data_flat = test_data.reshape((test_data.shape[0], -1))
ood_data_flat = ood_data.reshape((ood_data.shape[0], -1))

print(f'train_data_flat.shape={train_data_flat.shape}')
print(f'test_data_flat.shape={test_data_flat.shape}')
print(f'ood_data_flat.shape={ood_data_flat.shape}')

PYTHON

# Perform PCA to visualize the first two principal components
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
train_data_pca = pca.fit_transform(train_data_flat)
test_data_pca = pca.transform(test_data_flat)
ood_data_pca = pca.transform(ood_data_flat)

# Plotting PCA components
plt.figure(figsize=(10, 6))
scatter1 = plt.scatter(train_data_pca[train_labels == 0, 0], train_data_pca[train_labels == 0, 1], c='blue', label='T-shirt/top (ID)', alpha=0.5)
scatter2 = plt.scatter(train_data_pca[train_labels == 1, 0], train_data_pca[train_labels == 1, 1], c='red', label='Pants (ID)', alpha=0.5)
scatter3 = plt.scatter(ood_data_pca[:, 0], ood_data_pca[:, 1], c='green', label='Sandals (OOD)', edgecolor='k')

# Create a single legend for all classes
plt.legend(handles=[scatter1, scatter2, scatter3], loc="upper right")
plt.xlabel('First Principal Component')
plt.ylabel('Second Principal Component')
plt.title('PCA of In-Distribution and OOD Data')
plt.savefig('../images/OOD-detection_PCA-image-dataset.png', dpi=300, bbox_inches='tight')
plt.show()

PCA visualization

From this plot, we see that sandals are more likely to be confused with T-shirts than with pants. It may also be surprising that these data clouds overlap so much given their semantic differences. Why might this be?

  • Over-reliance on linear relationships: Part of this has to do with the fact that we’re only looking at linear relationships and treating each pixel as its own input feature, which is rarely a good idea when working with image data. In our next example, we’ll switch to the more modern approach of CNNs.
  • Semantic gap != feature gap: Another factor of note is that images with a wide semantic gap may not have a correspondingly wide gap in visual features (e.g., ankle boots and bags might both be small, made of leather, and have zippers). Part of an effective OOD detection scheme involves thinking carefully about what sorts of data contaminations the model may encounter, and assessing how similar those contaminations are to your desired class labels.

Train and evaluate model on ID data

PYTHON

# Train a logistic regression classifier
model = LogisticRegression(max_iter=max_iter, solver='lbfgs', multi_class='multinomial').fit(train_data_flat, train_labels)

Before we worry about the impact of OOD data, let’s first verify that we have a reasonably accurate model for the ID data.

PYTHON

# Evaluate the model on in-distribution data
in_dist_preds = model.predict(test_data_flat)
in_dist_accuracy = accuracy_score(test_labels, in_dist_preds)
print(f'In-Distribution Accuracy: {in_dist_accuracy:.2f}')

PYTHON

from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

# Generate and display confusion matrix
cm = confusion_matrix(test_labels, in_dist_preds, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['T-shirt/top', 'Pants'])
disp.plot(cmap=plt.cm.Blues)
plt.savefig('../images/OOD-detection_ID-confusion-matrix.png', dpi=300, bbox_inches='tight')
plt.show()
ID confusion matrix

How does our model view OOD data?

A basic question to start with: on average, how are OOD samples classified? Are they more likely to be labeled T-shirts or pants? To answer this, we can calculate the probability scores for the OOD data and compare them to the ID data.

PYTHON

# Predict probabilities using the model on OOD data (Sandals)
ood_probs = model.predict_proba(ood_data_flat)
avg_ood_prob = np.mean(ood_probs, 0)
print(f"Avg. probability of sandal being T-shirt: {avg_ood_prob[0]:.4f}")
print(f"Avg. probability of sandal being pants: {avg_ood_prob[1]:.4f}")

id_probs = model.predict_proba(train_data_flat)
id_probs_shirts = id_probs[train_labels==0,:]
id_probs_pants = id_probs[train_labels==1,:]
avg_tshirt_prob = np.mean(id_probs_shirts, 0)
avg_pants_prob = np.mean(id_probs_pants, 0)

print()
print(f"Avg. probability of T-shirt being T-shirt: {avg_tshirt_prob[0]:.4f}")
print(f"Avg. probability of pants being pants: {avg_pants_prob[1]:.4f}")

Based on the difference in averages here, it looks like softmax may provide at least a somewhat useful signal for separating ID and OOD data. Let’s take a closer look by plotting histograms of all probability scores across our classes of interest (ID T-shirt, ID Pants, and OOD).

PYTHON

# Creating the figure and subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=False)
bins=60
# Plotting the histogram of probabilities for OOD data (Sandals)
axes[0].hist(ood_probs[:, 0], bins=bins, alpha=0.5, label='T-shirt probability')
axes[0].set_xlabel('Probability')
axes[0].set_ylabel('Frequency')
axes[0].set_title('OOD Data (Sandals)')
axes[0].legend()

# Plotting the histogram of probabilities for ID data (T-shirt)
axes[1].hist(id_probs_shirts[:, 0], bins=bins, alpha=0.5, label='T-shirt probability', color='orange')
axes[1].set_xlabel('Probability')
axes[1].set_title('ID Data (T-shirt/top)')
axes[1].legend()

# Plotting the histogram of probabilities for ID data (Pants)
axes[2].hist(id_probs_pants[:, 1], bins=bins, alpha=0.5, label='Pants probability', color='green')
axes[2].set_xlabel('Probability')
axes[2].set_title('ID Data (Pants)')
axes[2].legend()

# Adjusting layout
plt.tight_layout()
plt.savefig('../images/OOD-detection_histograms.png', dpi=300, bbox_inches='tight')
# Displaying the plot
plt.show()

Histograms of ID and OOD data

Alternatively, for a better comparison across all three classes, we can use a probability density plot. This allows for an easier comparison when the counts across classes lie on vastly different scales (i.e., a max of 35 vs. a max of 5000).

PYTHON

from scipy.stats import gaussian_kde

# Create figure
plt.figure(figsize=(10, 6))

# Transparency for the overlapping density curves
alpha = 0.4

# Plot PDF for ID T-shirt (T-shirt probability)
density_id_shirts = gaussian_kde(id_probs_shirts[:, 0])
x_id_shirts = np.linspace(0, 1, 1000)
plt.plot(x_id_shirts, density_id_shirts(x_id_shirts), label='ID T-shirt (T-shirt probability)', color='orange', alpha=alpha)

# Plot PDF for ID Pants (Pants probability)
density_id_pants = gaussian_kde(id_probs_pants[:, 0])
x_id_pants = np.linspace(0, 1, 1000)
plt.plot(x_id_pants, density_id_pants(x_id_pants), label='ID Pants (T-shirt probability)', color='green', alpha=alpha)

# Plot PDF for OOD (T-shirt probability)
density_ood = gaussian_kde(ood_probs[:, 0])
x_ood = np.linspace(0, 1, 1000)
plt.plot(x_ood, density_ood(x_ood), label='OOD (T-shirt probability)', color='blue', alpha=alpha)

# Adding labels and title
plt.xlabel('Probability')
plt.ylabel('Density')
plt.title('Probability Density Distributions for OOD and ID Data')
plt.legend()

plt.savefig('../images/OOD-detection_PSDs.png', dpi=300, bbox_inches='tight')

# Displaying the plot
plt.show()

Probability densities

Unfortunately, we observe a significant amount of overlap between OOD data and high T-shirt probability. Furthermore, the blue line doesn’t seem to decrease much as you move from 0.9 to 1, suggesting that even a very high threshold is likely to lead to OOD contamination (while also tossing out a significant portion of ID data).

For pants, the problem is much less severe. It looks like a low threshold (on this T-shirt probability scale) can separate nearly all OOD samples from being pants.

Setting a threshold

Let’s put our observations to the test and produce a confusion matrix that includes ID Pants, ID T-shirt, and OOD class labels. We’ll start with a high threshold of 0.9 to see how that performs.

PYTHON

def softmax_thresh_classifications(probs, threshold):
    classifications = np.where(probs[:, 1] >= threshold, 1,  # classified as pants
                               np.where(probs[:, 0] >= threshold, 0,  # classified as shirts
                                        -1))  # classified as OOD
    return classifications

PYTHON

from sklearn.metrics import precision_recall_fscore_support

# Assuming ood_probs, id_probs, and train_labels are defined
# Threshold values
upper_threshold = 0.9

# Classifying OOD examples (sandals)
ood_classifications = softmax_thresh_classifications(ood_probs, upper_threshold)

# Classifying ID examples (T-shirts and pants)
id_classifications = softmax_thresh_classifications(id_probs, upper_threshold)

# Combine OOD and ID classifications and true labels
all_predictions = np.concatenate([ood_classifications, id_classifications])
all_true_labels = np.concatenate([-1 * np.ones(ood_classifications.shape), train_labels])

# Confusion matrix
cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

# Plotting the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix for OOD and ID Classification')

plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix1.png', dpi=300, bbox_inches='tight')

plt.show()

# Looking at F1, precision, and recall
precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[0, 1], average='macro')  # macro average over the ID classes (see callout on averaging schemes below)

print(f"F1: {f1}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")

Confusion matrix for OOD and ID classification (threshold = 0.9)

Even with a high threshold of 0.9, we end up with nearly two hundred OOD samples classified as ID. In addition, over 800 ID samples had to be tossed out due to uncertainty.

Quick exercise

What threshold is required to ensure that no OOD samples are incorrectly considered ID? What percentage of ID samples are mistaken as OOD at this threshold? Answer: 0.9999, (3826+2414)/(3826+2414+2174+3586) = 52%

With a very conservative threshold, we can make sure very few OOD samples are incorrectly classified as ID. However, the flip side is that conservative thresholds tend to incorrectly classify many ID samples as being OOD. In this case, we incorrectly assume almost 20% of shirts are OOD samples.

Iterative Threshold Determination

In practice, selecting an appropriate threshold is an iterative process that balances the trade-off between correctly identifying in-distribution (ID) data and accurately flagging out-of-distribution (OOD) data. Here’s how you can iteratively determine the threshold:

  • Define Evaluation Metrics: While confusion matrices are an excellent tool when you’re ready to examine the data more closely, we need a single metric that summarizes threshold performance so we can easily compare across thresholds. Common metrics include accuracy, precision, recall, or the F1 score for both ID and OOD detection.

  • Evaluate Over a Range of Thresholds: Test different threshold values and evaluate the performance on a validation set containing both ID and OOD data.

  • Select the Optimal Threshold: Choose the threshold that provides the best balance according to your chosen metrics.

Use the code below to determine what threshold should be set to ensure precision = 100%. What threshold is required for recall to be 100%? What threshold gives the highest F1 score?

Callout on averaging schemes

F1 scores can be calculated per class and then averaged in different ways (macro, micro, or weighted) when dealing with multiclass or multilabel classification problems. Here are the key types of averaging methods (a short sklearn sketch follows the list):

  • Macro-Averaging: Calculates the F1 score for each class independently and then takes the average of these scores. This treats all classes equally, regardless of their support (number of true instances for each class).

  • Micro-Averaging: Aggregates the contributions of all classes to compute the average F1 score. This is typically used for imbalanced datasets as it gives more weight to classes with more instances.

  • Weighted-Averaging: Calculates the F1 score for each class independently and then takes the average, weighted by the number of true instances for each class. This accounts for class imbalance by giving more weight to classes with more instances.
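As a quick illustration of how the averaging choice changes the number you report, here is a small sketch using `sklearn.metrics.f1_score` on made-up labels (the arrays are hypothetical, not taken from the lesson data):

PYTHON

from sklearn.metrics import f1_score

# Hypothetical imbalanced example: class 0 has 6 samples, class 1 has 2
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1, 1, 0]

print(f1_score(y_true, y_pred, average='macro'))     # treats both classes equally
print(f1_score(y_true, y_pred, average='micro'))     # pools all samples together
print(f1_score(y_true, y_pred, average='weighted'))  # weights classes by support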

Callout on including OOD data in F1 calculation

In the code below, F1 is computed over the ID classes only (labels=[0, 1]), so the OOD class (-1) affects the score only indirectly, through ID samples that get flagged as OOD and OOD samples that sneak in as ID. You can also include the OOD label in the average if you want the metric to reward OOD detection directly.
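To see the difference in practice, the sketch below (with hypothetical `y_true`/`y_pred` arrays, where -1 marks OOD) compares a macro average over the ID classes only with one that also includes the OOD label:

PYTHON

from sklearn.metrics import precision_recall_fscore_support

# Hypothetical labels: -1 = OOD, 0 = T-shirt, 1 = Pants
y_true = [0, 0, 1, 1, -1, -1]
y_pred = [0, 1, 1, 1, -1, 0]

# ID classes only (how thresholds are scored in the code below)
print(precision_recall_fscore_support(y_true, y_pred, labels=[0, 1], average='macro'))

# ID classes plus the OOD "class"
print(precision_recall_fscore_support(y_true, y_pred, labels=[0, 1, -1], average='macro'))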

PYTHON

# from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def eval_softmax_thresholds(thresholds, ood_probs, id_probs):
    # Store evaluation metrics for each threshold
    precisions = []
    recalls = []
    f1_scores = []
    
    for threshold in thresholds:
        # Classifying OOD examples (sandals)
        ood_classifications = softmax_thresh_classifications(ood_probs, threshold)
        
        # Classifying ID examples (T-shirts and pants)
        id_classifications = softmax_thresh_classifications(id_probs, threshold)
        
        # Combine OOD and ID classifications and true labels
        all_predictions = np.concatenate([ood_classifications, id_classifications])
        all_true_labels = np.concatenate([-1 * np.ones(ood_classifications.shape), train_labels])
        
        # Evaluate metrics
        precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[0, 1], average='macro')  # macro average over the ID classes (see averaging callout)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        
    return precisions, recalls, f1_scores

PYTHON

# Define thresholds to evaluate
thresholds = np.linspace(.5, 1, 50)

# Evaluate on all thresholds
precisions, recalls, f1_scores = eval_softmax_thresholds(thresholds, ood_probs, id_probs)

PYTHON

def plot_metrics_vs_thresholds(thresholds, f1_scores, precisions, recalls, OOD_signal):
    # Find the best thresholds for each metric
    best_f1_index = np.argmax(f1_scores)
    best_f1_threshold = thresholds[best_f1_index]
    
    best_precision_index = np.argmax(precisions)
    best_precision_threshold = thresholds[best_precision_index]
    
    best_recall_index = np.argmax(recalls)
    best_recall_threshold = thresholds[best_recall_index]
    
    print(f"Best F1 threshold: {best_f1_threshold}, F1 Score: {f1_scores[best_f1_index]}")
    print(f"Best Precision threshold: {best_precision_threshold}, Precision: {precisions[best_precision_index]}")
    print(f"Best Recall threshold: {best_recall_threshold}, Recall: {recalls[best_recall_index]}")

    # Create a new figure
    fig, ax = plt.subplots(figsize=(12, 8))

    # Plot metrics as functions of the threshold
    ax.plot(thresholds, precisions, label='Precision', color='g')
    ax.plot(thresholds, recalls, label='Recall', color='b')
    ax.plot(thresholds, f1_scores, label='F1 Score', color='r')
    
    # Add best threshold indicators
    ax.axvline(x=best_f1_threshold, color='r', linestyle='--', label=f'Best F1 Threshold: {best_f1_threshold:.2f}')
    ax.axvline(x=best_precision_threshold, color='g', linestyle='--', label=f'Best Precision Threshold: {best_precision_threshold:.2f}')
    ax.axvline(x=best_recall_threshold, color='b', linestyle='--', label=f'Best Recall Threshold: {best_recall_threshold:.2f}')
    ax.set_xlabel(f'{OOD_signal} Threshold')
    ax.set_ylabel('Metric Value')
    ax.set_title('Evaluation Metrics as Functions of Threshold')
    ax.legend()

    return fig, best_f1_threshold, best_precision_threshold, best_recall_threshold

PYTHON

fig, best_f1_threshold, best_precision_threshold, best_recall_threshold = plot_metrics_vs_thresholds(thresholds, f1_scores, precisions, recalls, 'Softmax')
fig.savefig('../images/OOD-detection_metrics_vs_softmax-thresholds.png', dpi=300, bbox_inches='tight')
Evaluation metrics as functions of the softmax threshold

PYTHON

# Threshold values
upper_threshold = best_f1_threshold
# upper_threshold = best_precision_threshold

# Classifying OOD examples (sandals)
ood_classifications = softmax_thresh_classifications(ood_probs, upper_threshold)

# Classifying ID examples (T-shirts and pants)
id_classifications = softmax_thresh_classifications(id_probs, upper_threshold)

# Combine OOD and ID classifications and true labels
all_predictions = np.concatenate([ood_classifications, id_classifications])
all_true_labels = np.concatenate([-1 * np.ones(ood_classifications.shape), train_labels])

# Confusion matrix
cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

# Plotting the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix for OOD and ID Classification')
plt.savefig('../images/OOD-detection_ID-OOD-confusion-matrix2.png', dpi=300, bbox_inches='tight')
plt.show()
Optimized threshold confusion matrix

Key Points

  • Softmax-based OOD detection uses the model’s output probabilities to identify instances that do not belong to the training distribution.
  • Threshold selection is critical and involves trade-offs between retaining in-distribution data and detecting OOD samples.
  • Visualizations such as PCA and probability density plots help illustrate how OOD data overlaps with in-distribution data in feature space.
  • While simple and widely used, softmax-based methods have limitations, including sensitivity to threshold choices and reduced reliability in high-dimensional settings.
  • Understanding softmax-based OOD detection lays the groundwork for exploring more advanced techniques like energy-based detection.

Content from OOD detection: energy


Last updated on 2024-11-15 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • What are energy-based methods for out-of-distribution (OOD) detection, and how do they compare to softmax-based approaches?
  • How does the energy metric enhance separability between in-distribution and OOD data?
  • What are the challenges and limitations of energy-based OOD detection methods?

Objectives

  • Understand the concept of energy-based OOD detection and its theoretical foundations.
  • Compare energy-based methods to softmax-based approaches, highlighting their strengths and limitations.
  • Learn how to implement energy-based OOD detection using tools like PyTorch-OOD.
  • Explore challenges in applying energy-based methods, including threshold tuning and generalizability to diverse OOD scenarios.

Example 2: Energy-Based OOD Detection

Traditional approaches, such as softmax-based methods, rely on output probabilities to flag OOD data. While simple and intuitive, these methods often struggle to distinguish OOD data effectively in complex scenarios, especially in high-dimensional spaces.

Energy-based OOD detection offers a modern and robust alternative. This “output-based” approach leverages the energy score, a scalar value derived from a model’s output logits, to measure the compatibility between input data and the model’s learned distribution. The energy score directly ties to the Gibbs distribution, which allows for a nuanced interpretation of data compatibility. By computing energy scores and interpreting them probabilistically, energy-based methods enhance separability between in-distribution (ID) and OOD data. These techniques, particularly effective with neural networks, address some of the key limitations of softmax-based approaches.

In this episode, we will explore the theoretical foundations of energy-based OOD detection, implement it using the PyTorch-OOD library, and compare its performance to softmax-based methods. Along the way, we will provide intuitive explanations of key concepts like the Gibbs distribution to ensure accessibility for ML practitioners. We will also discuss the challenges of energy-based methods, including threshold tuning and generalization to diverse OOD scenarios. This understanding will set the stage for hybrid and training-time regularization methods discussed in future episodes.

Intuition behind the Gibbs distribution and energy scores

The Gibbs distribution is a probability distribution used to model systems in equilibrium, and it connects naturally to the concept of energy. In the context of machine learning:

  • Energy as a compatibility measure: Think of energy as a score that measures how “compatible” a data point is with the model’s learned distribution. A lower energy score indicates higher compatibility (i.e., the model sees the input as likely to belong to the training distribution).

  • From energy to probability: The Gibbs distribution converts energy into probabilities. For an input \(x\), the probability of observing \(x\) is proportional to \(e^{-\text{Energy}(x)}\). This exponential relationship means that even small differences in energy can lead to significant changes in probability, making energy scores sensitive and effective for OOD detection.

  • Why energy works for OOD detection: OOD data often results in higher energy scores because the model’s output logits are less aligned with any known class. By setting a threshold on energy scores, we can distinguish OOD data from ID data more effectively than with softmax probabilities alone.

For practitioners, the Gibbs distribution provides a bridge between abstract “energy” values and interpretable probabilities, helping us reason about uncertainty and compatibility in a model’s predictions. In summary:

  • \(E(x, y)\) is the energy value assigned to the pairing of input \(x\) with label \(y\); if \(x\) and \(y\) are “compatible”, the energy is lower.

  • Energy can be turned into a probability through the Gibbs distribution, which normalizes over all possible labels: \(p(y \mid x) = \frac{e^{-E(x, y)}}{\sum_{y'} e^{-E(x, y')}}\) (the sum becomes an integral over all possible \(y\) in the continuous case).

  • With energy scores, ID and OOD distributions become much more separable than with softmax probabilities alone (a short code sketch follows the reference below).

Learn more: Liu et al., Energy-based Out-of-distribution Detection, NeurIPS 2020; https://arxiv.org/pdf/2010.03759
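To make the energy score concrete, here is a minimal sketch of the score from Liu et al., computed with `torch.logsumexp`. It assumes you have a trained classifier that returns logits; the random `logits` tensor below is just a placeholder.

PYTHON

import torch

def energy_score(logits, T=1.0):
    """Energy score from Liu et al. (2020): E(x) = -T * logsumexp(f(x) / T).
    Lower energy suggests ID; higher energy suggests OOD."""
    return -T * torch.logsumexp(logits / T, dim=1)

# Placeholder logits for a batch of 4 samples and 2 classes
logits = torch.randn(4, 2)
print(energy_score(logits))

In the example below, we will let the PyTorch-OOD library handle this calculation for us.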

Introducing PyTorch OOD

The PyTorch-OOD library provides methods for OOD detection and other closely related fields, such as anomaly detection and novelty detection. Visit the docs to learn more: pytorch-ood.readthedocs.io/en/latest/info.html

This library provides a streamlined way to calculate both energy and softmax scores from a trained model.

Setup example

In this example, we will train a CNN model on the Fashion MNIST dataset. We will then repeat a similar process as we did with softmax scores to evaluate how well the energy metric can separate ID and OOD data.

We’ll start fresh by loading our data again. This time, let’s treat all remaining classes in the Fashion MNIST dataset as OOD. This should yield a more robust model that is more reliable when presented with all kinds of data.

PYTHON

train_data, test_data, ood_data, train_labels, test_labels, ood_labels = prep_ID_OOD_datasests([0,1], list(range(2,10))) # use remaining 8 classes in dataset as OOD
fig = plot_data_sample(train_data, ood_data)
fig.savefig('../images/OOD-detection_image-data-preview.png', dpi=300, bbox_inches='tight')
plt.show()

Visualizing OOD and ID data

UMAP (or similar)

Recall in our previous example, we used PCA to visualize the ID and OOD data distributions. This was appropriate given that we were evaluating OOD/ID data in the context of a linear model. However, when working with nonlinear models such as CNNs, it makes more sense to investigate how the data is represented in a nonlinear space. Nonlinear embedding methods, such as Uniform Manifold Approximation and Projection (UMAP), are more suitable in such scenarios.

UMAP is a non-linear dimensionality reduction technique that preserves both the global structure and the local neighborhood relationships in the data. UMAP is often better at maintaining the continuity of data points that lie on non-linear manifolds. It can reveal nonlinear patterns and structures that PCA might miss, making it a valuable tool for analyzing ID and OOD distributions.

PYTHON

plot_umap = True  # set to False to skip the UMAP visualization (it can be slow)
if plot_umap:
    import umap
    # Flatten images for UMAP
    train_data_flat = train_data.reshape((train_data.shape[0], -1))
    test_data_flat = test_data.reshape((test_data.shape[0], -1))
    ood_data_flat = ood_data.reshape((ood_data.shape[0], -1))
    
    print(f'train_data_flat.shape={train_data_flat.shape}')
    print(f'test_data_flat.shape={test_data_flat.shape}')
    print(f'ood_data_flat.shape={ood_data_flat.shape}')
    
    # Perform UMAP to visualize the data
    umap_reducer = umap.UMAP(n_components=2, random_state=42)
    combined_data = np.vstack([train_data_flat, ood_data_flat])
    combined_labels = np.hstack([train_labels, np.full(ood_data_flat.shape[0], 2)])  # Use 2 for OOD class
    
    umap_results = umap_reducer.fit_transform(combined_data)
    
    # Split the results back into in-distribution and OOD data
    umap_in_dist = umap_results[:len(train_data_flat)]
    umap_ood = umap_results[len(train_data_flat):]

The warning message indicates that UMAP has overridden the n_jobs parameter to 1 due to the random_state being set. This behavior ensures reproducibility by using a single job. If you want to avoid the warning and still use parallelism, you can remove the random_state parameter. However, removing random_state will mean that the results might not be reproducible.

PYTHON

if plot_umap:
    umap_alpha = .02

    # Plotting UMAP components
    plt.figure(figsize=(10, 6))
    
    # Plot in-distribution data
    scatter1 = plt.scatter(umap_in_dist[train_labels == 0, 0], umap_in_dist[train_labels == 0, 1], c='blue', label='T-shirts (ID)', alpha=umap_alpha)
    scatter2 = plt.scatter(umap_in_dist[train_labels == 1, 0], umap_in_dist[train_labels == 1, 1], c='red', label='Trousers (ID)', alpha=umap_alpha)
    
    # Plot OOD data
    scatter3 = plt.scatter(umap_ood[:, 0], umap_ood[:, 1], c='green', label='OOD', edgecolor='k', alpha=umap_alpha)
    
    # Create a single legend for all classes
    plt.legend(handles=[scatter1, scatter2, scatter3], loc="upper right")
    plt.xlabel('First UMAP Component')
    plt.ylabel('Second UMAP Component')
    plt.title('UMAP of In-Distribution and OOD Data')
    plt.show()

Train CNN

PYTHON

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F

# Convert to PyTorch tensors and normalize
train_data_tensor = torch.tensor(train_data, dtype=torch.float32).unsqueeze(1) / 255.0
test_data_tensor = torch.tensor(test_data, dtype=torch.float32).unsqueeze(1) / 255.0
ood_data_tensor = torch.tensor(ood_data, dtype=torch.float32).unsqueeze(1) / 255.0

train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)

train_dataset = torch.utils.data.TensorDataset(train_data_tensor, train_labels_tensor)
test_dataset = torch.utils.data.TensorDataset(test_data_tensor, test_labels_tensor)
ood_dataset = torch.utils.data.TensorDataset(ood_data_tensor, torch.zeros(ood_data_tensor.shape[0], dtype=torch.long))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=64, shuffle=False)

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64*5*5, 128)  # 64 channels with 5x5 spatial dims after two conv + pool layers
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 64*5*5)  # flatten to match fc1's expected input size
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}')

train_model(model, train_loader, criterion, optimizer)

PYTHON

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Function to plot confusion matrix
def plot_confusion_matrix(labels, predictions, title):
    cm = confusion_matrix(labels, predictions, labels=[0, 1])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["T-shirt/top", "Trouser"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title(title)
    plt.show()

# Function to evaluate model on a dataset
def evaluate_model(model, dataloader, device):
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(preds.cpu().numpy())
    return np.array(all_labels), np.array(all_predictions)

# Evaluate on train data
train_labels, train_predictions = evaluate_model(model, train_loader, device)
plot_confusion_matrix(train_labels, train_predictions, "Confusion Matrix for Train Data")

# Evaluate on test data
test_labels, test_predictions = evaluate_model(model, test_loader, device)
plot_confusion_matrix(test_labels, test_predictions, "Confusion Matrix for Test Data")

# Evaluate on OOD data
ood_labels, ood_predictions = evaluate_model(model, ood_loader, device)
plot_confusion_matrix(ood_labels, ood_predictions, "Confusion Matrix for OOD Data")

PYTHON

from scipy.stats import gaussian_kde
from pytorch_ood.detector import EnergyBased
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

# Compute softmax scores
def get_softmax_scores(model, dataloader):
    model.eval()
    softmax_scores = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            softmax = torch.nn.functional.softmax(outputs, dim=1)
            softmax_scores.extend(softmax.cpu().numpy())
    return np.array(softmax_scores)

id_softmax_scores = get_softmax_scores(model, test_loader)
ood_softmax_scores = get_softmax_scores(model, ood_loader)

# Initialize the energy-based OOD detector
energy_detector = EnergyBased(model, t=1.0)

# Compute energy scores
def get_energy_scores(detector, dataloader):
    scores = []
    detector.model.eval()
    with torch.no_grad():
        for inputs, _ in dataloader:
            inputs = inputs.to(device)
            score = detector.predict(inputs)
            scores.extend(score.cpu().numpy())
    return np.array(scores)

id_energy_scores = get_energy_scores(energy_detector, test_loader)
ood_energy_scores = get_energy_scores(energy_detector, ood_loader)

import matplotlib.pyplot as plt


# Plot PSDs

# Function to plot PSD
def plot_psd(id_scores, ood_scores, method_name):
    plt.figure(figsize=(12, 6))
    alpha = 0.3

    # Plot PSD for ID scores
    id_density = gaussian_kde(id_scores)
    x_id = np.linspace(id_scores.min(), id_scores.max(), 1000)
    plt.plot(x_id, id_density(x_id), label=f'ID ({method_name})', color='blue', alpha=alpha)

    # Plot PSD for OOD scores
    ood_density = gaussian_kde(ood_scores)
    x_ood = np.linspace(ood_scores.min(), ood_scores.max(), 1000)
    plt.plot(x_ood, ood_density(x_ood), label=f'OOD ({method_name})', color='red', alpha=alpha)

    plt.xlabel('Score')
    plt.ylabel('Density')
    plt.title(f'Probability Density Distributions for {method_name} Scores')
    plt.legend()
    plt.show()

# Plot PSD for softmax scores
plot_psd(id_softmax_scores[:, 1], ood_softmax_scores[:, 1], 'Softmax')

# Plot PSD for energy scores
plot_psd(id_energy_scores, ood_energy_scores, 'Energy')

PYTHON

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay

# Define thresholds to evaluate
thresholds = np.linspace(id_energy_scores.min(), id_energy_scores.max(), 50)

# Store evaluation metrics for each threshold
accuracies = []
precisions = []
recalls = []
f1_scores = []

# True labels for OOD data (since they are not part of the original labels)
ood_true_labels = np.full(len(ood_energy_scores), -1)

# We need the test_labels to be aligned with the ID data
id_true_labels = test_labels[:len(id_energy_scores)]

for threshold in thresholds:
    # Classify OOD examples based on energy scores (score >= threshold -> OOD)
    ood_classifications = np.where(ood_energy_scores >= threshold, -1, 0)  # -1 = OOD, 0 = ID

    # Classify ID examples based on energy scores; samples below the threshold keep their true ID label
    id_classifications = np.where(id_energy_scores >= threshold, -1, id_true_labels)

    # Combine OOD and ID classifications and true labels
    all_predictions = np.concatenate([ood_classifications, id_classifications])
    all_true_labels = np.concatenate([ood_true_labels, id_true_labels])

    # Evaluate metrics
    precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[0, 1], average='macro')#, zero_division=0)
    accuracy = accuracy_score(all_true_labels, all_predictions)

    accuracies.append(accuracy)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Find the best thresholds for each metric
best_f1_index = np.argmax(f1_scores)
best_f1_threshold = thresholds[best_f1_index]

best_precision_index = np.argmax(precisions)
best_precision_threshold = thresholds[best_precision_index]

best_recall_index = np.argmax(recalls)
best_recall_threshold = thresholds[best_recall_index]

print(f"Best F1 threshold: {best_f1_threshold}, F1 Score: {f1_scores[best_f1_index]}")
print(f"Best Precision threshold: {best_precision_threshold}, Precision: {precisions[best_precision_index]}")
print(f"Best Recall threshold: {best_recall_threshold}, Recall: {recalls[best_recall_index]}")

# Plot metrics as functions of the threshold
plt.figure(figsize=(12, 8))
plt.plot(thresholds, precisions, label='Precision', color='g')
plt.plot(thresholds, recalls, label='Recall', color='b')
plt.plot(thresholds, f1_scores, label='F1 Score', color='r')

# Add best threshold indicators
plt.axvline(x=best_f1_threshold, color='r', linestyle='--', label=f'Best F1 Threshold: {best_f1_threshold:.2f}')
plt.axvline(x=best_precision_threshold, color='g', linestyle='--', label=f'Best Precision Threshold: {best_precision_threshold:.2f}')
plt.axvline(x=best_recall_threshold, color='b', linestyle='--', label=f'Best Recall Threshold: {best_recall_threshold:.2f}')

plt.xlabel('Threshold')
plt.ylabel('Metric Value')
plt.title('Evaluation Metrics as Functions of Threshold (Energy-Based OOD Detection)')
plt.legend()
plt.show()

PYTHON


import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix, ConfusionMatrixDisplay

def evaluate_ood_detection(id_scores, ood_scores, id_true_labels, id_predictions, ood_predictions, score_type='energy'):
    """
    Evaluate OOD detection based on either energy scores or softmax scores.

    Parameters:
    - id_scores: np.array, scores for in-distribution (ID) data
    - ood_scores: np.array, scores for out-of-distribution (OOD) data
    - id_true_labels: np.array, true labels for ID data
    - id_predictions: np.array, predicted labels for ID data
    - ood_predictions: np.array, predicted labels for OOD data
    - score_type: str, type of score used ('energy' or 'softmax')

    Returns:
    - Best thresholds for F1, Precision, and Recall
    - Plots of Precision, Recall, and F1 Score as functions of the threshold
    """
    # Define thresholds to evaluate
    if score_type == 'softmax':
        thresholds = np.linspace(0.5, 1.0, 200)
    else:
        thresholds = np.linspace(id_scores.min(), id_scores.max(), 50)

    # Store evaluation metrics for each threshold
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []

    # True labels for OOD data (since they are not part of the original labels)
    if score_type == "energy":
        ood_true_labels = np.full(len(ood_scores), -1)
    else:
        ood_true_labels = np.full(len(ood_scores[:,0]), -1)

    for threshold in thresholds:
        # Classify OOD examples based on scores
        if score_type == 'energy':
            ood_classifications = np.where(ood_scores >= threshold, -1, ood_predictions)
            id_classifications = np.where(id_scores >= threshold, -1, id_predictions)
        elif score_type == 'softmax':
            ood_classifications = np.where(ood_scores[:,0] <= threshold, -1, ood_predictions)
            id_classifications = np.where(id_scores[:,0] <= threshold, -1, id_predictions)
        else:
            raise ValueError("Invalid score_type. Use 'energy' or 'softmax'.")

        # Combine OOD and ID classifications and true labels
        all_predictions = np.concatenate([ood_classifications, id_classifications])
        all_true_labels = np.concatenate([ood_true_labels, id_true_labels])

        # Evaluate metrics
        precision, recall, f1, _ = precision_recall_fscore_support(all_true_labels, all_predictions, labels=[-1, 0], average='macro', zero_division=0)
        accuracy = accuracy_score(all_true_labels, all_predictions)

        accuracies.append(accuracy)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    # Find the best thresholds for each metric
    best_f1_index = np.argmax(f1_scores)
    best_f1_threshold = thresholds[best_f1_index]

    best_precision_index = np.argmax(precisions)
    best_precision_threshold = thresholds[best_precision_index]

    best_recall_index = np.argmax(recalls)
    best_recall_threshold = thresholds[best_recall_index]

    print(f"Best F1 threshold: {best_f1_threshold}, F1 Score: {f1_scores[best_f1_index]}")
    print(f"Best Precision threshold: {best_precision_threshold}, Precision: {precisions[best_precision_index]}")
    print(f"Best Recall threshold: {best_recall_threshold}, Recall: {recalls[best_recall_index]}")

    # Plot metrics as functions of the threshold
    plt.figure(figsize=(12, 8))
    plt.plot(thresholds, precisions, label='Precision', color='g')
    plt.plot(thresholds, recalls, label='Recall', color='b')
    plt.plot(thresholds, f1_scores, label='F1 Score', color='r')

    # Add best threshold indicators
    plt.axvline(x=best_f1_threshold, color='r', linestyle='--', label=f'Best F1 Threshold: {best_f1_threshold:.2f}')
    plt.axvline(x=best_precision_threshold, color='g', linestyle='--', label=f'Best Precision Threshold: {best_precision_threshold:.2f}')
    plt.axvline(x=best_recall_threshold, color='b', linestyle='--', label=f'Best Recall Threshold: {best_recall_threshold:.2f}')

    plt.xlabel('Threshold')
    plt.ylabel('Metric Value')
    plt.title(f'Evaluation Metrics as Functions of Threshold ({score_type.capitalize()}-Based OOD Detection)')
    plt.legend()
    plt.show()

    # plot confusion matrix
    # Threshold value for the energy score
    upper_threshold = best_f1_threshold  # Using the best F1 threshold from the previous calculation
    if score_type == 'energy':
        # Classifying OOD examples based on energy scores (use the function's arguments, not the globals)
        ood_classifications = np.where(ood_scores >= upper_threshold, -1, 0)  # -1 = OOD, otherwise ID
        # Classifying ID examples based on energy scores; samples below the threshold keep their true ID label
        id_classifications = np.where(id_scores >= upper_threshold, -1, id_true_labels)
    elif score_type == 'softmax':
        # Classifying OOD examples based on softmax scores
        ood_classifications = softmax_thresh_classifications(ood_scores, upper_threshold)

        # Classifying ID examples based on softmax scores
        id_classifications = softmax_thresh_classifications(id_scores, upper_threshold)
    # Combine OOD and ID classifications and true labels
    all_predictions = np.concatenate([ood_classifications, id_classifications])
    all_true_labels = np.concatenate([ood_true_labels, id_true_labels])
    # Confusion matrix
    cm = confusion_matrix(all_true_labels, all_predictions, labels=[0, 1, -1])

    # Plotting the confusion matrix
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Shirt", "Pants", "OOD"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title(f'Confusion Matrix for OOD and ID Classification ({score_type.capitalize()}-Based)')
    plt.show()


    return best_f1_threshold, best_precision_threshold, best_recall_threshold

# Example usage
# Assuming id_energy_scores, ood_energy_scores, id_true_labels, and test_labels are already defined
best_f1_threshold, best_precision_threshold, best_recall_threshold = evaluate_ood_detection(id_energy_scores, ood_energy_scores, test_labels, test_predictions, ood_predictions, score_type='energy')
best_f1_threshold, best_precision_threshold, best_recall_threshold = evaluate_ood_detection(id_softmax_scores, ood_softmax_scores, test_labels, test_predictions, ood_predictions, score_type='softmax')

Limitations of our approach thus far

  • Focus on a single OOD class: More reliable/accurate thresholds can (and should) be obtained using a wider variety of OOD classes and a larger sample of OOD data. This is part of the challenge of OOD detection: the space of possible OOD data is vast. Possible exercise: redo the thresholding using all remaining classes in the dataset.

References and supplemental resources

Key Points

  • Energy-based OOD detection is a modern and more robust alternative to softmax-based methods, leveraging energy scores to improve separability between in-distribution and OOD data.
  • By calculating an energy value for each input, these methods provide a more nuanced measure of compatibility between data and the model’s learned distribution.
  • Non-linear visualizations, like UMAP, offer better insights into how OOD and ID data are represented in high-dimensional feature spaces compared to linear methods like PCA.
  • PyTorch-OOD simplifies the implementation of energy-based and other OOD detection methods, making it accessible for real-world applications.
  • While energy-based methods excel in many scenarios, challenges include tuning thresholds across diverse OOD classes and ensuring generalizability to unseen distributions.
  • Transitioning to energy-based detection lays the groundwork for exploring training-time regularization and hybrid approaches.

Content from OOD detection: distance-based


Last updated on 2024-11-15 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • How do distance-based methods like Mahalanobis distance and KNN work for OOD detection?
  • What is contrastive learning and how does it improve feature representations?
  • How does contrastive learning enhance the effectiveness of distance-based OOD detection methods?

Objectives

  • Gain a thorough understanding of distance-based OOD detection methods, including Mahalanobis distance and KNN.
  • Learn the principles of contrastive learning and its role in improving feature representations.
  • Explore the synergy between contrastive learning and distance-based OOD detection methods to enhance detection performance.

Example 3: Distance-Based Methods

Lee et al., A simple unified framework for detecting out-of-distribution samples and adversarial attacks. NeurIPS 2018.

With softmax and energy-based methods, we focus on the model’s outputs to determine a threshold that separates ID from OOD data. With distance-based methods, we focus instead on the feature representations learned by the model.

In the case of neural networks, a common approach is to use the penultimate layer as a feature representation that defines an ID cluster for each class. You can then use the distance to the closest class centroid as a proxy for an OOD score.

Mahalanobis distance (parametric)

Model the feature space as a mixture of multivariate Gaussian distributions, one for each class, and use the Mahalanobis distance to the closest class centroid as a proxy for an OOD score.
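Below is a minimal NumPy sketch of this idea, assuming you have already extracted penultimate-layer features for the ID training set (a `features` array of shape `(n_samples, n_features)` and its `labels`); these names are placeholders rather than part of the lesson code.

PYTHON

import numpy as np

def fit_mahalanobis(features, labels):
    """Fit one mean per class and a single (tied) covariance on ID features."""
    classes = np.unique(labels)
    means = {c: features[labels == c].mean(axis=0) for c in classes}
    centered = np.concatenate([features[labels == c] - means[c] for c in classes])
    cov_inv = np.linalg.pinv(np.cov(centered, rowvar=False))  # pseudo-inverse for numerical stability
    return means, cov_inv

def mahalanobis_ood_score(x, means, cov_inv):
    """OOD score = Mahalanobis distance to the closest class centroid."""
    dists = [np.sqrt((x - mu) @ cov_inv @ (x - mu)) for mu in means.values()]
    return min(dists)  # larger distance suggests OOD

Note that Lee et al. add further refinements (e.g., input pre-processing and combining features from multiple layers); this sketch keeps only the core distance computation.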

Limitations of the parametric approach

  • Strong distributional assumption (features may not necessarily be Gaussian-distributed)
  • Suboptimal embedding

Nearest Neighbor Distance (non-parametric)

Sun et al., Out-of-distribution Detection with Deep Nearest Neighbors, ICML 2022

  • A sample is considered OOD if it has a large KNN distance w.r.t. the training data (and vice versa).
  • No distributional assumptions about the underlying embedding space, giving stronger generality and flexibility than the Mahalanobis distance (see the sketch below).
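Here is a minimal scikit-learn sketch of the KNN variant, reusing the same placeholder `features` matrix of ID training features as in the Mahalanobis sketch above. Following Sun et al., features are L2-normalized and the distance to the k-th nearest ID neighbor is used as the OOD score.

PYTHON

import numpy as np
from sklearn.neighbors import NearestNeighbors

def fit_knn_detector(features, k=50):
    """Index L2-normalized ID training features."""
    normed = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-12)
    return NearestNeighbors(n_neighbors=k).fit(normed)

def knn_ood_score(x_features, knn):
    """OOD score = distance to the k-th nearest ID training feature."""
    normed = x_features / (np.linalg.norm(x_features, axis=1, keepdims=True) + 1e-12)
    dists, _ = knn.kneighbors(normed)
    return dists[:, -1]  # larger distance suggests OOD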

CIDER

This one might be out of scope…

Ming et al., How to Exploit Hyperspherical Embeddings for Out-of-Distribution Detection

Contrastive Learning

  • Explain the basic idea of contrastive learning: learning representations by contrasting positive and negative pairs.
  • Highlight the role of contrastive learning in learning discriminative features that can separate in-distribution (ID) from OOD data more effectively.
  • Illustrate how contrastive learning improves the feature space, making distance-based methods (like Mahalanobis and KNN) more effective.
  • Provide examples or case studies where contrastive learning has been applied to enhance OOD detection.

Example X: Comparing feature representations with and without contrastive learning

Returning to UMAP

Notice how in our UMAP visualization, we saw three distinct clusters representing each class. However, our model still confidently rated many sandals as being T-shirts. The crux of this issue is that models do not know what they don’t know: they simply draw classification boundaries between the classes available to them during training.

One way to get around this problem is to train models to learn discriminative features…

Contrastive learning

In this experiment, we use both a traditional neural network and a contrastive learning model to classify images from the Fashion MNIST dataset, focusing on T-shirts (class 0) and Trousers (class 1). Additionally, we evaluate the models on out-of-distribution (OOD) data, specifically Sandals (class 5). To visualize the models’ learned features, we extract features from specific layers of the neural networks and reduce their dimensionality using UMAP.

Overview of steps

1) Train model

  • With or without contrastive learning
  • Focusing on T-shirts (class 0) and Trousers (class 1)
  • Additionally, we evaluate the models on out-of-distribution (OOD) data, specifically Sandals (class 5)

2) Feature Extraction:

  • After training, we set the models to evaluation mode to prevent updates to the model parameters.
  • For each subset of the data (training, validation, and OOD), we pass the images through the entire network up to the first fully connected layer.
  • The output of this layer, which captures high-level features and abstractions, is then used as a 1D feature vector.
  • These feature vectors are detached from the computational graph and converted to NumPy arrays for further processing.

3) Dimensionality Reduction and Visualization:

  • We combine the feature vectors from the training, validation, and OOD data into a single dataset.
  • UMAP (Uniform Manifold Approximation and Projection) is used to reduce the dimensionality of the feature vectors from the high-dimensional space to 2D, making it possible to visualize the relationships between different data points.
  • The reduced features are then plotted, with different colors representing the training data (T-shirts and Trousers), validation data (T-shirts and Trousers), and OOD data (Sandals).

By visualizing the features generated from different subsets of the data, we can observe how well the models have learned to distinguish between in-distribution classes (T-shirts and Trousers) and handle OOD data (Sandals). This approach allows us to evaluate the robustness and generalization capabilities of the models in dealing with data that may not have been seen during training.

Standard neural network without contrastive learning

1) Train model

We’ll first train our vanilla CNN without contrastive learning.

  • Focusing on T-shirts (class 0) and Trousers (class 1)
  • Additionally, we evaluate the models on out-of-distribution (OOD) data, specifically Sandals (class 5)

PYTHON

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset

# Check if GPU is available and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

PYTHON

# Define a simple CNN model for classification
class ClassificationModel(nn.Module):
    def __init__(self):
        super(ClassificationModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 2)  # 2 classes for T-shirts and Trousers

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Load Fashion MNIST dataset and filter for T-shirts and Trousers
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_indices = np.where((train_dataset.targets == 0) | (train_dataset.targets == 1))[0]
val_indices = np.where((test_dataset.targets == 0) | (test_dataset.targets == 1))[0]
ood_indices = np.where(test_dataset.targets == 5)[0]

# Use a subset of the data for quicker training
train_subset = Subset(train_dataset, np.random.choice(train_indices, size=5000, replace=False))
val_subset = Subset(test_dataset, np.random.choice(val_indices, size=1000, replace=False))
ood_subset = Subset(test_dataset, np.random.choice(ood_indices, size=1000, replace=False))

train_loader = DataLoader(train_subset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=256, shuffle=False)
ood_loader = DataLoader(ood_subset, batch_size=256, shuffle=False)

# Initialize the model and move it to the device
classification_model = ClassificationModel().to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classification_model.parameters(), lr=0.001)

# Training loop for standard neural network
train_losses = []
val_losses = []

for epoch in range(n_epochs):
    total_train_loss = 0
    classification_model.train()
    for batch_images, batch_labels in train_loader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        outputs = classification_model(batch_images)
        loss = criterion(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    total_val_loss = 0
    classification_model.eval()
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
            outputs = classification_model(batch_images)
            loss = criterion(outputs, batch_labels)
            total_val_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    avg_val_loss = total_val_loss / len(val_loader)
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

PYTHON

# Plot training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(range(1, n_epochs + 1), train_losses, label='Train Loss')
plt.plot(range(1, n_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss - Classification Model')
plt.legend()
plt.show()

2) Extracting learned features

  • After training, we set the models to evaluation mode to prevent updates to the model parameters.
  • For each subset of the data (training, validation, and OOD), we pass the images through the entire network up to the first fully connected layer.
  • The output of this layer, which captures high-level features and abstractions, is then used as a 1D feature vector.
  • These feature vectors are detached from the computational graph and converted to NumPy arrays for further processing.

Why later layer features are better

In both the traditional neural network and the contrastive learning model, we will extract features from the first fully connected layer (fc1) before the final classification layer. Here’s why this layer is particularly suitable for feature extraction:

  • Hierarchical feature representation: In neural networks, the initial layers typically capture low-level features such as edges, textures, and simple shapes (e.g., with CNNs). As you move deeper into the network, the layers capture higher-level, more abstract features that are more relevant for the final classification task. These high-level features are combinations of the low-level features and are typically more discriminative.

  • Better separation of classes: Features from later layers have been transformed through several layers of non-linear activations and pooling operations, making them more suitable for distinguishing between classes. These features are usually more compact and have a better separation in the feature space, which helps in visualization and understanding the model’s decision-making process.

PYTHON

# Extract features using the trained classification model
classification_model.eval()
train_features = []
train_labels_list = []
for batch_images, batch_labels in train_loader:
    batch_images = batch_images.to(device)
    features = classification_model.fc1(classification_model.flatten(classification_model.relu(classification_model.conv1(batch_images))))
    train_features.append(features.detach().cpu().numpy())
    train_labels_list.append(batch_labels.numpy())

val_features = []
val_labels_list = []
for batch_images, batch_labels in val_loader:
    batch_images = batch_images.to(device)
    features = classification_model.fc1(classification_model.flatten(classification_model.relu(classification_model.conv1(batch_images))))
    val_features.append(features.detach().cpu().numpy())
    val_labels_list.append(batch_labels.numpy())

ood_features = []
ood_labels_list = []
for batch_images, batch_labels in ood_loader:
    batch_images = batch_images.to(device)
    features = classification_model.fc1(classification_model.flatten(classification_model.conv1(batch_images)))
    ood_features.append(features.detach().cpu().numpy())
    ood_labels_list.append(batch_labels.numpy())
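
As an optional aside, the direct layer calls above assume we know the model's internal layer names and the order they are applied in. A more general way to grab intermediate activations is to register a forward hook on the layer of interest. The snippet below is a minimal sketch, assuming the classification_model, train_loader, and device defined earlier in this episode.

PYTHON

import numpy as np
import torch

# Minimal sketch: capture fc1 activations with a forward hook
captured = []

def capture_fc1(module, inputs, output):
    # Called automatically every time fc1 runs a forward pass
    captured.append(output.detach().cpu().numpy())

hook_handle = classification_model.fc1.register_forward_hook(capture_fc1)

classification_model.eval()
with torch.no_grad():
    for batch_images, _ in train_loader:
        classification_model(batch_images.to(device))  # full forward pass; the hook records fc1 outputs

hook_handle.remove()  # always remove hooks when you are done with them
train_features_from_hook = np.concatenate(captured)
print(train_features_from_hook.shape)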

3) Dimensionality reduction and visualization

  • We combine the feature vectors from the training, validation, and OOD data into a single dataset.
  • UMAP (Uniform Manifold Approximation and Projection) is used to reduce the dimensionality of the feature vectors from the high-dimensional space to 2D, making it possible to visualize the relationships between different data points.
  • The reduced features are then plotted, with different colors representing the training data (T-shirts and Trousers), validation data (T-shirts and Trousers), and OOD data (Sandals).

PYTHON

train_features = np.concatenate(train_features)
train_labels = np.concatenate(train_labels_list)
val_features = np.concatenate(val_features)
val_labels = np.concatenate(val_labels_list)
ood_features = np.concatenate(ood_features)
ood_labels = np.concatenate(ood_labels_list)

# Perform UMAP to visualize the classification model features
combined_features = np.vstack([train_features, val_features, ood_features])
combined_labels = np.hstack([train_labels, val_labels, np.full(len(ood_labels), 2)])  # Use 2 for OOD class

umap_reducer = umap.UMAP(n_components=2, random_state=42)
umap_results = umap_reducer.fit_transform(combined_features)

# Split the results back into train, val, and OOD data
umap_train_features = umap_results[:len(train_features)]
umap_val_features = umap_results[len(train_features):len(train_features) + len(val_features)]
umap_ood_features = umap_results[len(train_features) + len(val_features):]

PYTHON

# Plotting UMAP components for classification model
alpha = .2
plt.figure(figsize=(10, 6))
# Plot train T-shirts
scatter1 = plt.scatter(umap_train_features[train_labels == 0, 0], umap_train_features[train_labels == 0, 1], c='blue', alpha=alpha, label='Train T-shirts (ID)')
# Plot train Trousers
scatter2 = plt.scatter(umap_train_features[train_labels == 1, 0], umap_train_features[train_labels == 1, 1], c='red', alpha=alpha, label='Train Trousers (ID)')
# Plot val T-shirts
scatter3 = plt.scatter(umap_val_features[val_labels == 0, 0], umap_val_features[val_labels == 0, 1], c='blue', alpha=alpha, marker='x', label='Val T-shirts (ID)')
# Plot val Trousers
scatter4 = plt.scatter(umap_val_features[val_labels == 1, 0], umap_val_features[val_labels == 1, 1], c='red', alpha=alpha, marker='x', label='Val Trousers (ID)')
# Plot OOD Sandals
scatter5 = plt.scatter(umap_ood_features[:, 0], umap_ood_features[:, 1], c='green', alpha=alpha, marker='o', label='OOD Sandals')
plt.legend(handles=[scatter1, scatter2, scatter3, scatter4, scatter5])
plt.xlabel('First UMAP Component')
plt.ylabel('Second UMAP Component')
plt.title('UMAP of Classification Model Features')
plt.show()

Neural network trained with contrastive learning

What is Contrastive Learning?

Contrastive learning is a technique where the model learns to distinguish between similar and dissimilar pairs of data. This can be achieved through different types of learning: supervised, unsupervised, and self-supervised.

  • Supervised Contrastive Learning: Uses labeled data to create pairs or groups of similar and dissimilar data points based on their labels.

  • Unsupervised Contrastive Learning: Does not use any labels. Instead, it relies on inherent patterns in the data to create pairs. For example, random pairs of data points might be assumed to be dissimilar, while augmented versions of the same data point might be assumed to be similar.

  • Self-Supervised Contrastive Learning: A form of unsupervised learning where the model generates its own supervisory signal from the data. This typically involves data augmentation techniques where positive pairs are created by augmenting the same image (e.g., cropping, rotating), and negative pairs are formed from different images.

In contrastive learning, the model learns to bring similar pairs closer in the embedding space while pushing dissimilar pairs further apart. This approach is particularly useful for tasks like image retrieval, clustering, and representation learning.
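
To make this idea concrete, here is a minimal sketch of the classic pairwise (margin-based) contrastive loss. It is a generic illustration (the function name is ours), not the exact batch-wise loss used later in this episode.

PYTHON

import torch
import torch.nn.functional as F

def pairwise_contrastive_loss(z1, z2, same_class, margin=1.0):
    """Classic margin-based contrastive loss for a batch of embedding pairs.

    z1, z2:      embeddings of the two items in each pair, shape (batch, dim)
    same_class:  1.0 for similar pairs, 0.0 for dissimilar pairs, shape (batch,)
    """
    dist = F.pairwise_distance(z1, z2)                            # Euclidean distance per pair
    pos_term = same_class * dist.pow(2)                           # pull similar pairs together
    neg_term = (1.0 - same_class) * F.relu(margin - dist).pow(2)  # push dissimilar pairs apart, up to the margin
    return (pos_term + neg_term).mean()

# Toy usage with random embeddings
z1, z2 = torch.randn(8, 50), torch.randn(8, 50)
same_class = torch.randint(0, 2, (8,)).float()
print(pairwise_contrastive_loss(z1, z2, same_class))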

Let’s now look at how the T-shirt, Trouser, and Sandal classes are treated in our supervised contrastive learning framework.

Key Concepts in Our Code

Data Preparation

  • Dataset: We use the Fashion MNIST dataset, which contains images of various clothing items, each labeled with a specific class.
  • Class Filtering: For this exercise, we are focusing on three classes from the Fashion MNIST dataset:
    • T-shirts (class label 0)
    • Trousers (class label 1)
    • Sandals (class label 5)
  • In-Distribution (ID) Data: We treat T-shirts and Trousers as our primary classes for training. These are considered “in-distribution” data.
  • Out-of-Distribution (OOD) Data: Sandals are treated as a different class for testing the robustness of our learned embeddings, making them “out-of-distribution” data.

Pairs Creation

For each image in our training set:

  • Positive pair: we sample another image of the same class (T-shirt or Trouser); the two images are treated as similar.
  • Negative pairs: within each training batch, the contrastive loss treats all other images as dissimilar to a given image and pushes their embeddings apart (in this simplified setup, some of those in-batch negatives may share the same class).

By creating these pairs, the model learns to produce embeddings where similar images (same class) are close together, and dissimilar images (different classes) are farther apart.

Model Architecture

The model is a simple convolutional neural network (CNN) designed to output embeddings. It consists of:

  • A convolutional layer to extract features from the images.
  • Two fully connected layers that map these features to a 50-dimensional embedding space.

Training Process

  • Forward Pass: The model processes pairs of images and outputs their embeddings.
  • Contrastive Loss: We use a contrastive loss function to train the model. This loss encourages embeddings of similar pairs to be close and embeddings of dissimilar pairs to be far apart. Specifically, we:
    • Normalize the embeddings.
    • Calculate similarity scores.
    • Compute the contrastive loss, which penalizes similar pairs that are not close enough and dissimilar pairs that are too close (a toy sketch of these steps follows this list).
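
As referenced above, here is a toy sketch of the normalize-then-compare steps. The tensors are random and purely illustrative.

PYTHON

import torch
import torch.nn.functional as F

# Toy example: cosine similarities between two batches of (random) embeddings
z_i = F.normalize(torch.randn(4, 50), dim=1)   # 4 embeddings of dimension 50, unit length
z_j = F.normalize(torch.randn(4, 50), dim=1)

sim = z_i @ z_j.t()   # cosine similarity matrix; entries lie in [-1, 1]
print(sim.diag())     # similarity of each (z_i[k], z_j[k]) positive pair
print(sim)            # the remaining entries act as negatives in the batch-wise loss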

Differences from Standard Neural Network Training

  • Data Pairing: In contrastive learning, we create pairs of data points. Standard neural network training typically involves individual data points with corresponding labels.
  • Loss Function: We use a contrastive loss function instead of the typical cross-entropy loss used in classification tasks. The contrastive loss is designed to optimize the relative distances between pairs of embeddings.
  • Supervised Learning: Our approach uses labeled data to form similar and dissimilar pairs, making it supervised contrastive learning. This contrasts with self-supervised or unsupervised methods where labels are not used.

Specific Type of Contrastive Learning

The specific contrastive learning technique we are using here is a form of supervised contrastive learning. This involves using labeled data to create similar and dissimilar pairs of images. The model is trained to output embeddings where a contrastive loss function is applied to these pairs. By doing so, the model learns to map images into an embedding space where similar images are close together, and dissimilar images are farther apart.

By training with this method, the model learns robust feature representations that are useful for various downstream tasks, even with limited labeled data. This is powerful because it allows leveraging labeled data to improve the model’s performance and generalizability.

Application of the Framework

  1. Training with In-Distribution Data:
    • T-shirts and Trousers: These classes are used to train the model. Positive and negative pairs are created within this subset to teach the model to distinguish between the two classes.
  2. Testing with Out-of-Distribution Data:
    • Sandals: This class is used to test the robustness of the embeddings learned by the model. By introducing a completely different class during testing, we can evaluate how well the model generalizes to new, unseen data.

This framework demonstrates how supervised contrastive learning can be effectively applied to learn discriminative embeddings that can generalize well to both in-distribution and out-of-distribution data.

PYTHON

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import umap
import torch.nn as nn
import torch.optim as optim

class PairDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img1 = self.images[idx]
        label1 = self.labels[idx]
        idx2 = np.random.choice(np.where(self.labels == label1)[0])
        img2 = self.images[idx2]
        return img1, img2, label1

# Load Fashion MNIST dataset and filter for T-shirts and Trousers
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_indices = np.where((train_dataset.targets == 0) | (train_dataset.targets == 1))[0]
val_indices = np.where((test_dataset.targets == 0) | (test_dataset.targets == 1))[0]
ood_indices = np.where(test_dataset.targets == 5)[0]

# Use a random subset of the data for quicker training
train_indices = np.random.choice(train_indices, size=5000, replace=False)
val_indices = np.random.choice(val_indices, size=1000, replace=False)
ood_indices = np.random.choice(ood_indices, size=1000, replace=False)

# Build image/label arrays for the selected indices
train_images = np.array([train_dataset[i][0].numpy() for i in train_indices])
train_labels = np.array([train_dataset[i][1] for i in train_indices])
val_images = np.array([test_dataset[i][0].numpy() for i in val_indices])
val_labels = np.array([test_dataset[i][1] for i in val_indices])
ood_images = np.array([test_dataset[i][0].numpy() for i in ood_indices])
ood_labels = np.array([test_dataset[i][1] for i in ood_indices])

train_loader = DataLoader(PairDataset(train_images, train_labels), batch_size=256, shuffle=True)
val_loader = DataLoader(PairDataset(val_images, val_labels), batch_size=256, shuffle=False)
ood_loader = DataLoader(PairDataset(ood_images, ood_labels), batch_size=256, shuffle=False)

# Inspect the data loaders
for batch_images1, batch_images2, batch_labels in train_loader:
    print(f"train_loader batch_images1 shape: {batch_images1.shape}")
    print(f"train_loader batch_images2 shape: {batch_images2.shape}")
    print(f"train_loader batch_labels shape: {batch_labels.shape}")
    break

for batch_images1, batch_images2, batch_labels in val_loader:
    print(f"val_loader batch_images1 shape: {batch_images1.shape}")
    print(f"val_loader batch_images2 shape: {batch_images2.shape}")
    print(f"val_loader batch_labels shape: {batch_labels.shape}")
    break

for batch_images1, batch_images2, batch_labels in ood_loader:
    print(f"ood_loader batch_images1 shape: {batch_images1.shape}")
    print(f"ood_loader batch_images2 shape: {batch_images2.shape}")
    print(f"ood_loader batch_labels shape: {batch_labels.shape}")
    break

PYTHON

# Define a simple CNN model for contrastive learning
class ContrastiveModel(nn.Module):
    def __init__(self):
        super(ContrastiveModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 28 * 28, 128)
        self.fc2 = nn.Linear(128, 50)  # Embedding size

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Define a simplified batch-wise contrastive loss: maximize the similarity of each
# positive pair while minimizing the average similarity to the rest of the batch
def contrastive_loss(z_i, z_j, temperature=0.5):
    z_i = nn.functional.normalize(z_i, dim=1)
    z_j = nn.functional.normalize(z_j, dim=1)
    batch_size = z_i.size(0)
    z = torch.cat([z_i, z_j], dim=0)

    # Cosine similarity matrix between all 2N embeddings, scaled by the temperature
    sim = torch.mm(z, z.t()) / temperature
    # Similarities of the positive pairs (z_i[k] with z_j[k], in both orders)
    sim_i_j = torch.diag(sim, batch_size)
    sim_j_i = torch.diag(sim, -batch_size)

    positives = torch.cat([sim_i_j, sim_j_i], dim=0)
    # Everything except self-similarity is treated as a negative in this simplified loss
    negatives_mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    negatives = sim[negatives_mask].view(2 * batch_size, -1)

    loss = -torch.mean(positives) + torch.mean(negatives)
    return loss

# Training loop for contrastive learning
def train_contrastive_model(model, train_loader, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for img1, img2, _ in train_loader:
            img1, img2 = img1.to(device), img2.to(device)

            optimizer.zero_grad()

            z_i = model(img1)
            z_j = model(img2)

            loss = contrastive_loss(z_i, z_j)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Instantiate the model, optimizer, and start training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
contrastive_model = ContrastiveModel().to(device)
optimizer = optim.Adam(contrastive_model.parameters(), lr=0.001)

train_contrastive_model(contrastive_model, train_loader, optimizer, num_epochs=n_epochs)

2) Extracting learned features

  • After training, we switch the model to evaluation mode with model.eval(), which disables training-only behavior such as dropout and batch-norm updates (note that evaluation mode alone does not stop gradient tracking).
  • For each subset of the data (training, validation, and OOD), we pass the images through the entire network up to the first fully connected layer.
  • The output of this layer, which captures high-level features and abstractions, is then used as a 1D feature vector.
  • These feature vectors are detached from the computational graph and converted to NumPy arrays for further processing.

PYTHON

# Extract features using the trained contrastive model
contrastive_model.eval()
train_features = []
train_labels_list = []
for img1, _, label1 in train_loader:
    img1 = img1.to(device)
    # Simplified forward pass up to fc1 (conv1 -> flatten -> fc1); note that, unlike the model's
    # forward() method, this path skips the ReLU after conv1 (the same applies to the loops below)
    features = contrastive_model.fc1(contrastive_model.flatten(contrastive_model.conv1(img1)))
    train_features.append(features.detach().cpu().numpy())
    train_labels_list.append(label1.numpy())

val_features = []
val_labels_list = []
for img1, _, label1 in val_loader:
    img1 = img1.to(device)
    features = contrastive_model.fc1(contrastive_model.flatten(contrastive_model.conv1(img1)))
    val_features.append(features.detach().cpu().numpy())
    val_labels_list.append(label1.numpy())

ood_features = []
ood_labels_list = []
for img1, _, label1 in ood_loader:
    img1 = img1.to(device)
    features = contrastive_model.fc1(contrastive_model.flatten(contrastive_model.conv1(img1)))
    ood_features.append(features.detach().cpu().numpy())
    ood_labels_list.append(label1.numpy())

train_features = np.concatenate(train_features)
train_labels = np.concatenate(train_labels_list)
val_features = np.concatenate(val_features)
val_labels = np.concatenate(val_labels_list)
ood_features = np.concatenate(ood_features)
ood_labels = np.concatenate(ood_labels_list)

# Diagnostic print statements
print(f"train_features shape: {train_features.shape}")
print(f"train_labels shape: {train_labels.shape}")
print(f"val_features shape: {val_features.shape}")
print(f"val_labels shape: {val_labels.shape}")
print(f"ood_features shape: {ood_features.shape}")
print(f"ood_labels shape: {ood_labels.shape}")

3) Dimensionality reduction and visualization

  • We combine the feature vectors from the training, validation, and OOD data into a single dataset.
  • UMAP (Uniform Manifold Approximation and Projection) is used to reduce the dimensionality of the feature vectors from the high-dimensional space to 2D, making it possible to visualize the relationships between different data points.
  • The reduced features are then plotted, with different colors representing the training data (T-shirts and Trousers), validation data (T-shirts and Trousers), and OOD data (Sandals).

PYTHON

# Ensure the labels array for OOD matches the feature array length
combined_features = np.vstack([train_features, val_features, ood_features])
combined_labels = np.hstack([train_labels, val_labels, np.full(len(ood_features), 2)])  # Use 2 for OOD class

umap_reducer = umap.UMAP(n_components=2, random_state=42)
umap_results = umap_reducer.fit_transform(combined_features)

# Split the results back into train, val, and OOD data
umap_train_features = umap_results[:len(train_features)]
umap_val_features = umap_results[len(train_features):len(train_features) + len(val_features)]
umap_ood_features = umap_results[len(train_features) + len(val_features):]

# Plotting UMAP components for contrastive learning model
plt.figure(figsize=(10, 6))
# Plot train T-shirts
scatter1 = plt.scatter(umap_train_features[train_labels == 0, 0], umap_train_features[train_labels == 0, 1], c='blue', alpha=0.5, label='Train T-shirts (ID)')
# Plot train Trousers
scatter2 = plt.scatter(umap_train_features[train_labels == 1, 0], umap_train_features[train_labels == 1, 1], c='red', alpha=0.5, label='Train Trousers (ID)')
# Plot val T-shirts
scatter3 = plt.scatter(umap_val_features[val_labels == 0, 0], umap_val_features[val_labels == 0, 1], c='blue', alpha=0.5, marker='x', label='Val T-shirts (ID)')
# Plot val Trousers
scatter4 = plt.scatter(umap_val_features[val_labels == 1, 0], umap_val_features[val_labels == 1, 1], c='red', alpha=0.5, marker='x', label='Val Trousers (ID)')
# Plot OOD Sandals
scatter5 = plt.scatter(umap_ood_features[:, 0], umap_ood_features[:, 1], c='green', alpha=0.5, marker='o', label='OOD Sandals')
plt.legend(handles=[scatter1, scatter2, scatter3, scatter4, scatter5])
plt.xlabel('First UMAP Component')
plt.ylabel('Second UMAP Component')
plt.title('UMAP of Contrastive Model Features')
plt.show()

Limitations of Threshold-Based OOD Detection Methods

Threshold-based out-of-distribution (OOD) detection methods are widely used due to their simplicity and intuitive nature. However, they come with several significant limitations that need to be considered (a minimal code sketch of the basic thresholding approach follows the list below):

  1. Dependence on OOD Data Choice:
    • Variety and Representation: The effectiveness of threshold-based methods heavily relies on the variety and representativeness of the OOD data used during threshold selection. If the chosen OOD samples do not adequately cover the possible range of OOD scenarios, the threshold may not generalize well to unseen OOD data.
    • Threshold Determination: To determine a robust threshold, it is essential to include a diverse set of OOD samples. This helps in setting a threshold that can effectively distinguish between in-distribution and out-of-distribution data across various scenarios. Without a comprehensive OOD dataset, the threshold might either be too conservative, causing many ID samples to be misclassified as OOD, or too lenient, failing to detect OOD samples accurately.
  2. Impact of High Thresholds:
    • False OOD Classification: High thresholds can lead to a significant number of ID samples being incorrectly classified as OOD. This false OOD classification results in the loss of potentially valuable data, reducing the efficiency and performance of the model.
    • Data Efficiency: In applications where retaining as much ID data as possible is crucial, high thresholds can be particularly detrimental. It’s important to strike a balance between detecting OOD samples and retaining ID samples to ensure the model’s overall performance and data efficiency.
  3. Sensitivity to Model Confidence:
    • Model Calibration: Threshold-based methods rely on the model’s confidence scores, which can be misleading if the model is poorly calibrated. Overconfident predictions for ID samples or underconfident predictions for OOD samples can result in suboptimal threshold settings.
    • Confidence Variability: The variability in confidence scores across different models and architectures can make it challenging to set a universal threshold. Each model might require different threshold settings, complicating the deployment and maintenance of threshold-based OOD detection systems.
  4. Lack of Discriminative Features:
    • Boundary-Based Detection: Threshold-based methods focus on class boundaries rather than learning discriminative features that can effectively separate ID and OOD samples. This approach can be less robust, particularly in complex or high-dimensional data spaces where class boundaries might be less clear.
    • Feature Learning: By relying solely on confidence scores, these methods miss the opportunity to learn and leverage features that are inherently more discriminative. This limitation highlights the need for advanced techniques like contrastive learning, which focuses on learning features that distinguish between ID and OOD samples more effectively.
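
For reference, the basic thresholding approach that these limitations apply to can be sketched in a few lines. This is an illustration rather than the episode's exact implementation: id_loader and ood_loader are hypothetical loaders assumed to yield (image, label) batches like the classification loaders used earlier, and the threshold value is arbitrary.

PYTHON

import torch
import torch.nn.functional as F

@torch.no_grad()
def max_softmax_scores(model, loader, device):
    """Return the maximum softmax probability per sample (higher = more ID-like)."""
    model.eval()
    scores = []
    for images, _ in loader:  # assumes (image, label) batches
        logits = model(images.to(device))
        scores.append(F.softmax(logits, dim=1).max(dim=1).values.cpu())
    return torch.cat(scores)

# Hypothetical usage: anything scoring below the threshold is flagged as OOD
id_scores = max_softmax_scores(classification_model, id_loader, device)
ood_scores = max_softmax_scores(classification_model, ood_loader, device)
threshold = 0.95  # choosing this value well is exactly the fragile step described above
print("ID samples flagged as OOD:", (id_scores < threshold).float().mean().item())
print("OOD samples correctly flagged:", (ood_scores < threshold).float().mean().item())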

Conclusion

While threshold-based OOD detection methods offer a straightforward approach, their limitations underscore the importance of considering additional OOD samples for robust threshold determination and the potential pitfalls of high thresholds. Transitioning to methods that learn discriminative features rather than relying solely on class boundaries can address these limitations, paving the way for more effective OOD detection. This sets the stage for discussing contrastive learning, which provides a powerful framework for learning such discriminative features.

Content from OOD detection: training-time regularization


Last updated on 2024-11-15 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • What are the key considerations when designing algorithms for OOD detection?
  • How can OOD detection be incorporated into the loss functions of models?
  • What are the challenges and best practices for training models with OOD detection capabilities?

Objectives

  • Understand the critical design considerations for creating effective OOD detection algorithms.
  • Learn how to integrate OOD detection into the loss functions of machine learning models.
  • Identify the challenges in training models with OOD detection and explore best practices to overcome these challenges.

Training-time regularization for OOD detection

Training-time regularization methods improve OOD detection by incorporating penalties into the training process. These penalties encourage the model to handle OOD data effectively, either by:

  • Penalizing high confidence on OOD samples,
  • Optimizing feature representations to separate ID and OOD data,
  • Enhancing robustness to adversarial or ambiguous inputs.

The following methods apply these penalties in different ways: outlier exposure, contrastive learning, confidence penalties, and adversarial training.

2a) Outlier exposure

Outlier Exposure (OE) penalizes high confidence on OOD samples by introducing auxiliary datasets during training. This method teaches the model to differentiate OOD data from ID data.

How it works:

  • Use a curated auxiliary dataset of OOD samples that differ from the training distribution.
  • Augment the training loss function to penalize high confidence on these auxiliary samples.
  • Resulting models are less likely to misclassify OOD inputs as ID.

Advantages:

  • Simple to implement when auxiliary datasets are available.
  • Improves OOD detection performance without significant computational cost.

Limitations:

  • Requires access to high-quality, diverse OOD datasets during training.
  • Performance may degrade for OOD samples dissimilar to the auxiliary dataset.
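
A minimal sketch of the outlier-exposure objective is shown below: the usual cross-entropy on ID batches plus a term that pushes predictions on auxiliary OOD batches toward the uniform distribution. The function and argument names are hypothetical; this illustrates the idea rather than any specific library's API.

PYTHON

import torch
import torch.nn.functional as F

def outlier_exposure_loss(model, id_images, id_labels, ood_images, lam=0.5):
    """Cross-entropy on ID data plus a penalty on confident predictions for auxiliary OOD data."""
    id_loss = F.cross_entropy(model(id_images), id_labels)
    # Encourage near-uniform (low-confidence) predictions on the auxiliary OOD batch
    ood_logits = model(ood_images)
    ood_loss = -F.log_softmax(ood_logits, dim=1).mean()
    return id_loss + lam * ood_loss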

2b) Contrastive learning

Contrastive learning optimizes feature representations by applying penalties that control the similarity of embeddings. Positive pairs (similar samples) are brought closer together, while negative pairs (dissimilar samples) are pushed apart. This results in a feature space where OOD data is less likely to overlap with ID data.

How it works:

  • Define a contrastive loss that minimizes the distance between embeddings of similar samples (e.g., belonging to the same class).
  • Simultaneously maximize the distance between embeddings of dissimilar samples (e.g., ID vs. synthetic or auxiliary OOD samples).
  • Often uses data augmentation or self-supervised techniques to generate “positive” and “negative” pairs.

Advantages:

  • Does not require labeled auxiliary OOD data, as augmentations or unsupervised data can be used.
  • Improves the quality of learned representations, benefiting other tasks.

Limitations:

  • Computationally expensive, especially with large datasets.
  • Requires careful tuning of the contrastive loss and data augmentation strategy.
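
For instance, one common way to build positive pairs without any labels is to apply two random augmentations to the same image (a self-supervised, SimCLR-style setup). The sketch below uses torchvision transforms and is illustrative only; the function name is ours.

PYTHON

import torch
from torchvision import transforms

# Two random augmentations of the same image form a positive pair (no labels needed)
augment = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
])

def make_positive_pair(image):
    """image: tensor of shape (1, 28, 28); returns two independently augmented views."""
    return augment(image), augment(image)

view1, view2 = make_positive_pair(torch.rand(1, 28, 28))
print(view1.shape, view2.shape)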

2c) Other regularization-based techniques

Other methods incorporate penalties directly into the training process to improve robustness to OOD data:

  • Confidence penalties: Penalize overconfidence in predictions, especially on ambiguous samples.
  • Adversarial training: Generate adversarial examples (slightly perturbed ID samples) to penalize high confidence on these perturbed examples, improving robustness.

Advantages:

  • Enhances OOD detection performance by integrating it into the training process.
  • Leads to better generalization for both ID and OOD scenarios.

Limitations:

  • Requires careful design of the training procedure and loss function.
  • Computationally intensive and may need access to additional datasets.
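
As a concrete illustration of a confidence penalty, the sketch below adds an entropy bonus to the usual cross-entropy so that overconfident (low-entropy) predictions are penalized. The function name and the beta value are hypothetical.

PYTHON

import torch
import torch.nn.functional as F

def confidence_penalty_loss(logits, labels, beta=0.1):
    """Cross-entropy minus a small bonus for high prediction entropy (discourages overconfidence)."""
    log_probs = F.log_softmax(logits, dim=1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=1).mean()  # average prediction entropy
    return F.cross_entropy(logits, labels) - beta * entropy

# Toy usage with random logits
logits = torch.randn(16, 2)
labels = torch.randint(0, 2, (16,))
print(confidence_penalty_loss(logits, labels))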

Summary of Training-Time Regularization Methods

  • Outlier Exposure: penalizes high confidence on auxiliary OOD data. Advantages: simple to implement, improves performance. Limitations: requires high-quality auxiliary datasets and may not generalize to unseen OOD data.
  • Contrastive Learning: penalizes high embedding similarity between dissimilar samples (and low similarity between similar samples). Advantages: improves feature-space quality, versatile. Limitations: computationally expensive, requires careful tuning.
  • Confidence Penalties: penalize overconfidence on ambiguous inputs. Advantages: improve robustness, generalize well. Limitations: require careful design, computationally intensive.
  • Adversarial Training: penalizes high confidence on adversarial examples. Advantages: enhances robustness to perturbed inputs. Limitations: computationally intensive, challenging to implement.

Key Points

  • Training-time regularization enhances OOD detection by incorporating techniques like Outlier Exposure and Contrastive Learning during model training.

Content from Documenting and releasing a model


Last updated on 2024-09-25 | Edit this page

Estimated time: 0 minutes

Overview

Questions

  • Why is model sharing important in the context of reproducibility and responsible use?
  • What are the challenges, risks, and ethical considerations related to sharing models?
  • How can model-sharing best practices be applied using tools like model cards and the Hugging Face platform?

Objectives

  • Understand the importance of model sharing and best practices to ensure reproducibility and responsible use of models.
  • Understand the challenges, risks, and ethical concerns associated with model sharing.
  • Apply model-sharing best practices through using model cards and the Hugging Face platform.

Key Points

  • Model cards are the standard technique for communicating information about how machine learning systems were trained and how they should and should not be used.
  • Models can be shared and reused via the Hugging Face platform.

Why should we share trained models?

Discuss in small groups and report out: Why do you believe it is or isn’t important to share ML models? How has model-sharing contributed to your experiences or projects?

  • Accelerating research: Sharing models allows researchers and practitioners to build upon existing work, accelerating the pace of innovation in the field.
  • Knowledge exchange: Model sharing promotes knowledge exchange and collaboration within the machine learning community, fostering a culture of open science.
  • Reproducibility: Sharing models, along with associated code and data, enhances reproducibility, enabling others to validate and verify the results reported in research papers.
  • Benchmarking: Shared models serve as benchmarks for comparing new models and algorithms, facilitating the evaluation and improvement of state-of-the-art techniques.
  • Education / Accessibility to state-of-the-art architectures: Shared models provide valuable resources for educational purposes, allowing students and learners to explore and experiment with advanced machine learning techniques.
  • Repurpose (transfer learning and fine-tuning): Some models (e.g., foundation models) can be repurposed for a wide variety of tasks. This is especially useful when working with limited data (data scarcity).
  • Resource efficiency: Instead of training a model from the ground up, practitioners can use existing models as a starting point, saving time, computational resources, and energy.

Challenges and risks of model sharing

Discuss in small groups and report out: What are some potential challenges, risks, or ethical concerns associated with model sharing and reproducing ML workflows?

  • Privacy concerns: Sharing models that were trained on sensitive or private data raises privacy concerns. The potential disclosure of personal information through the model poses a risk to individuals and can lead to unintended consequences.
  • Informed consent: If models involve user data, ensuring informed consent is crucial. Sharing models trained on user-generated content without clear consent may violate privacy norms and regulations.
  • Intellectual property: Models may be developed within organizations with proprietary data and methodologies. Sharing such models without proper consent or authorization may lead to intellectual property disputes and legal consequences.
  • Model robustness and generalization: Reproduced models may not generalize well to new datasets or real-world scenarios. Failure to account for the limitations of the original model can result in reduced performance and reliability in diverse settings.
  • Lack of reproducibility: Incomplete documentation, missing details, or changes in dependencies over time can hinder the reproducibility of ML workflows. This lack of reproducibility can impede scientific progress and validation of research findings.
  • Unintended use and misuse: Shared models may be used in unintended ways, leading to ethical concerns. Developers should consider the potential consequences of misuse, particularly in applications with societal impact, such as healthcare or law enforcement.
  • Responsible AI considerations: Ethical considerations, such as fairness, accountability, and transparency, should be addressed during model sharing. Failing to consider these aspects can result in models that inadvertently discriminate or lack interpretability. Models used for decision-making, especially in critical areas like healthcare or finance, should be ethically deployed. Transparent documentation and disclosure of how decisions are made are essential for responsible AI adoption.

Saving the model locally


Let’s review the simplest method for sharing a model first — saving the model locally. When working with PyTorch, it’s important to know how to save and load models efficiently. This process ensures that you can pause your work, share your models, or deploy them for inference without having to retrain them from scratch each time.

Defining the model

As an example, we’ll configure a simple perceptron (single hidden layer) in PyTorch. We’ll define a bare bones class for this just so we can initialize the model.

PYTHON

from typing import Dict, Any
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, config: Dict[str, int]):
        super().__init__()
        # Parameter is a trainable tensor initialized with random values
        self.param = nn.Parameter(torch.rand(config["num_channels"], config["hidden_size"]))
        # Linear layer (fully connected layer) for the output
        self.linear = nn.Linear(config["hidden_size"], config["num_classes"])
        # Store the configuration
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Forward pass: Add the input to the param tensor, then pass through the linear layer
        return self.linear(x + self.param)

Initialize model by calling the class with configuration settings.

PYTHON

# Create model instance with specific configuration
config = {"num_channels": 3, "hidden_size": 32, "num_classes": 10}
model = MyModel(config=config)

We can then write a function to save out the model. We’ll need both the model weights and the model’s configuration (hyperparameter settings). We’ll save the configurations as a json since a key/value format is convenient here.

PYTHON

import json

# Function to save model and config locally
def save_model(model: nn.Module, model_path: str, config_path: str) -> None:
    # Save model state dict (weights and biases) as a .pth file
    torch.save(model.state_dict(), model_path)
    # Save config
    with open(config_path, 'w') as f:
        json.dump(model.config, f)

PYTHON

# Save the model and config locally
save_model(model, "my_awesome_model.pth", "my_awesome_model_config.json")

To load the model back in, we can write another function

PYTHON

# Function to load model and config locally
def load_model(model_class: Any, model_path: str, config_path: str) -> nn.Module:
    # Load config
    with open(config_path, 'r') as f:
        config = json.load(f)
    # Create model instance with config
    model = model_class(config=config)
    # Load model state dict
    model.load_state_dict(torch.load(model_path))
    return model

PYTHON

# Load the model and config locally
loaded_model = load_model(MyModel, "my_awesome_model.pth", "my_awesome_model_config.json")

# Verify the loaded model
print(loaded_model)

Saving a model to Hugging Face


To share your model with a wider audience, we recommend uploading your model to Hugging Face. Hugging Face is a very popular machine learning (ML) platform and community that helps users build, deploy, share, and train machine learning models. It has quickly become the go-to option for sharing models with the public.

Create a Hugging Face account and access Token

If you haven’t completed these steps from the setup, make sure to do this now.

Create account: To create an account on Hugging Face, visit: huggingface.co/join. Enter an email address and password, and follow the instructions provided via Hugging Face (you may need to verify your email address) to complete the process.

Setup access token: Once you have your account created, you’ll need to generate an access token so that you can upload/share models to your Hugging Face account during the workshop. To generate a token, visit the Access Tokens setting page after logging in.

Login to Hugging Face account

To login, you will need to retrieve your access token from the Access Tokens setting page

PYTHON

!huggingface-cli login

You might get a message saying you cannot authenticate through git-credential as no helper is defined on your machine. This warning message should not stop you from being able to complete this episode, but it may mean that the token won’t be stored on your machine for future use.
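
If the command-line login gives you trouble (for example, in a notebook environment), you can log in from Python instead using huggingface_hub's login helper:

PYTHON

# Alternative to the CLI: log in from Python
from huggingface_hub import notebook_login

notebook_login()  # prompts for your access token inside the notebook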

Once logged in, we will need to edit our model class definition so that it supports Hugging Face’s push_to_hub method. To enable this functionality, you’ll need to include the PyTorchModelHubMixin “mixin class” provided by the huggingface_hub library. A mixin class is a type of class used in object-oriented programming to “mix in” additional properties and methods into a class. The PyTorchModelHubMixin class adds methods to your PyTorch model that enable easy saving and loading from the Hugging Face Model Hub.

Here’s how you can adjust the code to incorporate both saving/loading locally and pushing the model to the Hugging Face Hub.

PYTHON

from huggingface_hub import PyTorchModelHubMixin # NEW

class MyModel(nn.Module, PyTorchModelHubMixin): # PyTorchModelHubMixin is new
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        # Initialize layers and parameters
        self.param = nn.Parameter(torch.rand(config["num_channels"], config["hidden_size"]))
        self.linear = nn.Linear(config["hidden_size"], config["num_classes"])
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x + self.param)

PYTHON

# Create model instance with specific configuration
config = {"num_channels": 3, "hidden_size": 32, "num_classes": 10}
model = MyModel(config=config)
print(model)

PYTHON

# push to the hub
model.push_to_hub("my-awesome-model", config=config)

Verifying: To check your work, head back over to your Hugging Face account and click your profile icon in the top-right of the website. Click “Profile” from there to view all of your uploaded models. Alternatively, you can search for your username (or model name) from the Model Hub.
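
You can also verify the upload programmatically with the huggingface_hub client. A minimal sketch (replace the username placeholder with your own):

PYTHON

from huggingface_hub import HfApi

api = HfApi()
# List the models under your account (replace "your-username" with your Hugging Face username)
for model_info in api.list_models(author="your-username"):
    print(model_info.id)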

Loading the model from Hugging Face

PYTHON

# reload
model = MyModel.from_pretrained("your-username/my-awesome-model")

Uploading transformer models to Hugging Face


Key Differences

  • Saving and Loading the Tokenizer: Transformer models require a tokenizer that needs to be saved and loaded with the model. This is not necessary for custom PyTorch models that typically do not require a separate tokenizer.
  • Using Pre-trained Classes: Transformer models use classes like AutoModelForSequenceClassification and AutoTokenizer from the transformers library, which are pre-built and designed for specific tasks (e.g., sequence classification).
  • Methods for Saving and Loading: The transformers library provides save_pretrained and from_pretrained methods for both models and tokenizers, which handle the serialization and deserialization processes seamlessly.

PYTHON

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load a pre-trained model and tokenizer
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Save the model and tokenizer locally
model.save_pretrained("my_transformer_model")
tokenizer.save_pretrained("my_transformer_model")

# Load the model and tokenizer from the saved directory
loaded_model = AutoModelForSequenceClassification.from_pretrained("my_transformer_model")
loaded_tokenizer = AutoTokenizer.from_pretrained("my_transformer_model")

# Verify the loaded model and tokenizer
print(loaded_model)
print(loaded_tokenizer)

PYTHON

# Push the model and tokenizer to Hugging Face Hub
model.push_to_hub("my-awesome-transformer-model")
tokenizer.push_to_hub("my-awesome-transformer-model")

# Load the model and tokenizer from the Hugging Face Hub
hub_model = AutoModelForSequenceClassification.from_pretrained("user-name/my-awesome-transformer-model")
hub_tokenizer = AutoTokenizer.from_pretrained("user-name/my-awesome-transformer-model")

# Verify the model and tokenizer loaded from the hub
print(hub_model)
print(hub_tokenizer)

What pieces must be well-documented to ensure reproducible and responsible model sharing?

Discuss in small groups and report out: What type of information needs to be included in the documentation when sharing a model?

  • Environment setup
  • Training data
    • How the data was collected
    • Who owns the data: data license and usage terms
    • Basic descriptive statistics: number of samples, features, classes, etc.
    • Note any class imbalance or general bias issues
    • Description of data distribution to help prevent out-of-distribution failures.
  • Preprocessing steps.
    • Data splitting
    • Standardization method
    • Feature selection
    • Outlier detection and other filters
  • Model architecture, hyperparameters and, training procedure (e.g., dropout or early stopping)
  • Model weights
  • Evaluation metrics. Results and performance. The more tasks/datasets you can evaluate on, the better.
  • Ethical considerations: Include investigations of bias/fairness when applicable (i.e., if your model involves human data or affects decision-making involving humans)
  • Contact info
  • Acknowledgments
  • Examples and demos (highly recommended)
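
Much of this information lives in the model card, which is the README of your Hugging Face repository. As a minimal sketch, you can draft and push one programmatically with huggingface_hub's ModelCard class; the repository name and card text below are placeholders for your own content.

PYTHON

from huggingface_hub import ModelCard

card_text = """
---
license: mit
tags:
  - image-classification
---

# My Awesome Model

Toy model shared for demonstration purposes only. Describe your training data,
evaluation results, limitations, and ethical considerations here.
"""

card = ModelCard(card_text)
card.push_to_hub("your-username/my-awesome-model")  # replace with your repository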

Document your model

For this challenge, you have two options:

  1. Start writing a model card for a model you have created for your research. The solution from the previous challenge or this template from Hugging Face are good places to start, but note that not all fields may be relevant, depending on what your model does.

  2. Find a model on Hugging Face that has a model card. For example, you could search for models using terms like “sentiment classification” or “medical”. Read the model card and evaluate whether the information is clear and complete. Would you be able to recreate the model based on the information presented? Do you feel there is enough information to evaluate whether you could adapt this model for your purposes? You can refer to the previous challenge’s solution for ideas of what information should be included, but note that not all sections are relevant to all models.

Pair up with a classmate and discuss what you wrote/read. Do model cards seem like a useful tool for you moving forwards?