Validating LLM Responses: Pydantic & Instructor Integration with LLMs (Part II)
Getting structured output from LLMs is a notoriously tricky problem to solve given their non-deterministic nature. In our previous post, we explored using Pydantic and Instructor to ensure reliable JSON output from LLMs. Now, in Part 2, we dive deeper into Pydantic's validation capabilities. By setting up max_retries, we give our model multiple chances to produce accurate responses. If validations fail or data is improperly structured, Instructor handles retries until we get the desired result. This powerful feature is seamlessly managed with the Instructor patch for our OpenAI client.
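(If you're following along from Part 1, the examples below assume an OpenAI client patched with Instructor and exposed as instructor_openai_client; a minimal sketch of that setup, with a placeholder API key, might look like the following.)

import instructor
from openai import OpenAI

# Patch the OpenAI client so chat.completions.create accepts
# response_model and max_retries (client name assumed from Part 1).
instructor_openai_client = instructor.patch(OpenAI(api_key="YOUR_API_KEY"))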
Implementing Validation with Pydantic
Pydantic gives us a handy way to check whether the responses we get are correct. It verifies that the structure and types of the data match what we expect, and it lets us layer on extra checks of our own. If the LLM's response doesn't fit the bill, we can ask the model to try again using max_retries, passing the validation error back so it can correct itself.
Let's go back to our example where we wanted to pull out important details from patient data. Here, we'll see how Pydantic, along with some smart validation tricks, can make sure we always get reliable and useful answers.
Implementing Detailed Validations
We can add specific checks to our response model. Since our output is organized like a structured form, adding checks is as easy as setting rules to follow.
Let's simplify this with an example about age validation. We want to ensure that the age entered is reasonable, so we'll make a rule that the age must be under 200. If it's not, we'll raise a ValueError saying, "Age must be less than 200."
from typing import Optional

def validate_age(age: int) -> Optional[int]:
    # Accept plausible ages; anything else raises the error that gets fed back to the model.
    if age < 200:
        return age
    else:
        raise ValueError("Age must be less than 200")
Now, let's include this rule in our class definition using Pydantic.
from pydantic import BaseModel, AfterValidator
from typing_extensions import Annotated, Literal

class PatientInfo(BaseModel):
    sex: Literal['M', 'F']
    age: Annotated[int, AfterValidator(validate_age)]
    geographical_region: str
This ensures that any age assigned to the model meets our standards.
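To see the check in action, here's a quick, hypothetical example: a plausible age parses cleanly, while an implausible one raises the very error message that can be fed back to the model.

from pydantic import ValidationError

# A plausible age passes validation.
patient = PatientInfo(sex="M", age=45, geographical_region="Europe")

# An implausible age raises a ValidationError carrying our message.
try:
    PatientInfo(sex="F", age=250, geographical_region="Asia")
except ValidationError as err:
    print(err)  # includes "Age must be less than 200"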
The max_retries setting is what ties this together. It gives our model multiple attempts to provide the correct response: if validation fails or the data isn't structured correctly, Instructor sends the error message back to the model and asks again, up to the number of retries we allow. Because our OpenAI client is patched with Instructor, all of this happens inside the same create call.
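As a rough sketch (assuming medical_info holds the raw patient text from Part 1), a call that lets Instructor retry on validation failures might look like this:

patient_info = instructor_openai_client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": f"Extract the patient details from the following data: {medical_info}",
        }
    ],
    response_model=PatientInfo,  # validated against our Pydantic model
    max_retries=3,  # validation errors are sent back to the model up to 3 times
)
print(patient_info.age)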
Using LLM Validators
Instructor offers a handy way to use the LLM itself to check whether certain fields or logic are correct. Let's explain this with a simple example: we want to make sure that the pathology we extract is a real disease name.
from typing import List

from instructor import llm_validator
from pydantic import BaseModel, BeforeValidator
from typing_extensions import Annotated

class Symptoms(BaseModel):
    description: str
    pain_type: str
    locations: List[str]
    intensity: int
    location_precision: int
    pace: int

class MedicalHistory(BaseModel):
    pathology: Annotated[
        str,
        BeforeValidator(
            llm_validator(
                "must be a valid disease name", openai_client=instructor_openai_client
            )
        ),
    ]
    symptoms: Symptoms
    increase_with_exertion: bool
    alleviate_with_rest: bool
class PatientData(BaseModel):
    patient_info: PatientInfo
    medical_history: MedicalHistory
    risk_factors: RiskFactors
    differential_diagnosis: List[DifferentialDiagnosis]
This means we're using the LLM itself to double-check that the disease names we're given are real, like having a helpful assistant who makes sure everything makes sense before we move forward.
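To illustrate (with made-up values, and keeping in mind that the LLM-backed check itself makes an API call), constructing a MedicalHistory with a nonsense pathology should fail validation:

from pydantic import ValidationError

try:
    MedicalHistory(
        pathology="spontaneous sock loss",  # clearly not a disease name
        symptoms=Symptoms(
            description="sharp pain",
            pain_type="stabbing",
            locations=["lower chest"],
            intensity=6,
            location_precision=7,
            pace=4,
        ),
        increase_with_exertion=True,
        alleviate_with_rest=True,
    )
except ValidationError as err:
    print(err)  # explains why the value failed the "valid disease name" check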
Dynamic Validation Based on Runtime Data
Sometimes, making sure we capture all the right information is tricky, because the rules for what counts as correct change depending on the data itself. This comes up a lot when we're extracting details for programmatic use, such as building API responses, where validation has to flex with each request.
Let's imagine a situation where we're pulling out details from a JSON file. We want to make sure we grab everything we need, and the rules for what we need might change depending on what's in the JSON. This was a problem I faced when I was trying to gather detailed information about medical diagnoses from an AI assistant.
Imagine we want to improve our diagnosis system. We don't just want the name of the diagnosis and how likely it is. We also want a clear explanation of what the diagnosis means and what treatments are possible.
from typing import List

from pydantic import BaseModel, Field, create_model
class Treatment(BaseModel):
    treatment_name: str
    treatment_procedure: str
    comments: str

class Diagnosis(BaseModel):
    diagnosis_name: str
    diagnosis_description: str = Field(
        ...,
        description="A professional but human way to describe the diagnosis and its effect.",
    )
    probability: int
    possible_treatment: List[Treatment]
    description: str = Field(
        ...,
        description="A professional but sympathetic description of its treatment and its validity regarding other treatments according to you.",
    )
def update_models_description(possible_diagnosis: List[DifferentialDiagnosis]):
    # Build a string like "('Unstable angina', 26),('Stable angina', 24)" from the runtime data
    diagnosis = ",".join(
        [str((diag.disease_name, diag.probability)) for diag in possible_diagnosis]
    )
    print(diagnosis)

    # Create a new model whose field description embeds the diagnoses we expect back
    updated_model = create_model(
        "DiagnosisAndTreatments",
        disease=(
            List[Diagnosis],
            Field(
                ...,
                description=f"Ensure that all the following possible information is present for all the following diagnoses: {diagnosis}",
            ),
        ),
    )
    return updated_model
completion = instructor_openai_client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {
            "role": "user",
            "content": f"You are an AI doctor assistant. Please use the following patient data to extrapolate the requested details: {medical_info}",
        }
    ],
    max_retries=3,
    response_model=update_models_description(medical_data.differential_diagnosis),
)
print(type(completion))
print(json.dumps(completion.model_dump(), indent=1))
<class '__main__.DiagnosisAndTreatments'>
{
"disease": [
{
"diagnosis_name": "Unstable angina",
"diagnosis_description": "Unstable angina is a condition where the blood flow to the heart is suddenly blocked, leading to chest pain and discomfort. This condition can be life-threatening and requires immediate medical attention.",
"probability": 26,
"possible_treatment": [
{
"treatment_name": "Coronary angiography",
"treatment_procedure": "This procedure uses contrast dye and X-rays to see inside the arteries of the heart. It can help identify blockages and other problems in the blood vessels around the heart.",
"comments": "Coronary angiography is a common and effective procedure to diagnose and treat heart conditions."
},
{
"treatment_name": "Medications",
"treatment_procedure": "Medications such as nitroglycerin and beta-blockers may be prescribed to manage symptoms and prevent future episodes of unstable angina.",
"comments": "Medications play a crucial role in controlling symptoms and improving the prognosis of unstable angina patients."
}
],
"description": "Unstable angina is a serious condition that requires prompt diagnosis and treatment to prevent heart damage and complications."
},
{
"diagnosis_name": "Stable angina",
...
},
]
}
In this setup, we're building a Pydantic model on the fly based on the possible diagnoses we're dealing with. This means our checks are customized to fit each diagnosis, making sure we cover all the bases.
Plus, by adding Pydantic checks to our model, we make sure we don't miss anything important for each diagnosis. Whether it's the name, description, likelihood, or treatment options, we're double-checking everything to make sure our data is complete and accurate.
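And since completion is an instance of the dynamically created DiagnosisAndTreatments model, its fields can be consumed like any other validated Pydantic object, for example:

for diagnosis in completion.disease:
    print(f"{diagnosis.diagnosis_name} ({diagnosis.probability}%)")
    for treatment in diagnosis.possible_treatment:
        print(f"  - {treatment.treatment_name}")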