From Classification to Connectivity: Generating Liquid-Handling Workflows with GNNs
Modern Graph Neural Networks (GNNs) excel at predicting node and edge attributes, but many practical problems require changing the graph itself. Liquid-handling protocols are a prime example: executing a protocol means constructing a sequence of transfers that incrementally grows a workflow graph while respecting hard physical and chemical constraints. This post sketches how to adapt GNNs from attribute prediction to connectivity generation for liquid handling.
Background: GNN building blocks
If you are new to GNNs, I recommend the clear and interactive overview in Distill’s “A Gentle Introduction to Graph Neural Networks” (Sanchez-Lengeling et al., 2021). It explains message passing, aggregation and update functions, and how information flows over graph structure.
Key takeaways for our setting:
- GNNs operate over nodes, edges, and (optionally) global features via permutation-invariant aggregation.
- Information is localized and propagates by hops, which is useful for enforcing local constraints (e.g., volumes in a well or sterility across edges) while letting global features (e.g., temperature, instrument state) influence decisions.
Problem framing: protocols as dynamic DAGs
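We represent a liquid-handling protocol as a dynamic directed acyclic graph (DAG) over time steps t = 0..T. Parallel edges are allowed, so strictly speaking it is a directed acyclic multigraph. The DAG constraint is fundamental because time flows forward and reagents cannot be “un-mixed” or “un-transferred.” To make this precise, nodes are time-indexed states: an operation reads the current state of its endpoints and creates new versions of them, stamped with the operation's time. Liquid can move from well A to well B and later back to A, but the second transfer connects newer versions of both wells, so the graph stays acyclic.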
- Nodes: containers/wells, instrument resources (tips, reservoirs), intermediate mixtures, deck locations.
- Edges: operations such as aspirate, dispense, transfer, mix; edges carry attributes (volume, liquid identity, tip id, speed, timestamp).
- State: per-node attributes (current volume, composition, contamination risk), per-edge attributes (history), and global context (robot capabilities, timing, environment).
- DAG invariant: every edge (u, v) must satisfy timestamp(u) < timestamp(v). Timestamps therefore strictly increase along every path, so no path can return to its starting node and no cycle can exist.
At each step we add one or more edges that modify node states while preserving the DAG property. Generation ends when goals are satisfied (target mixture, plate layout) or no valid actions remain.
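One way to see why the DAG invariant holds is to create a new node version for a well every time an operation touches it. The sketch below is a minimal plain-Python illustration with hypothetical helper names. It assumes a protocol is simply a list of (source, target, timestamp) transfers with timestamps ≥ 1 and strictly increasing. It uses the standard library's `graphlib` to check acyclicity and contrasts the versioned graph with a graph that has one node per well.

```python
from graphlib import CycleError, TopologicalSorter
from typing import Dict, List, Set, Tuple

Node = Tuple[str, int]  # (well id, timestamp of this version)


def build_versioned_graph(transfers: List[Tuple[str, str, int]]) -> Dict[Node, Set[Node]]:
    """Map each well version to its predecessor versions.

    Assumes timestamps are >= 1 and strictly increasing; version 0 is each
    well's initial state.
    """
    latest: Dict[str, Node] = {}
    preds: Dict[Node, Set[Node]] = {}
    for source, target, t in sorted(transfers, key=lambda tr: tr[2]):
        src_old = latest.setdefault(source, (source, 0))
        tgt_old = latest.setdefault(target, (target, 0))
        preds.setdefault(src_old, set())
        preds.setdefault(tgt_old, set())
        # The source's new version follows its old one (it lost volume);
        # the target's new version depends on both inputs (it received liquid)
        preds[(source, t)] = {src_old}
        preds[(target, t)] = {tgt_old, src_old}
        latest[source], latest[target] = (source, t), (target, t)
    return preds


def is_acyclic(preds) -> bool:
    """graphlib raises CycleError when the graph cannot be topologically sorted"""
    try:
        list(TopologicalSorter(preds).static_order())
        return True
    except CycleError:
        return False


def edges_respect_time(preds: Dict[Node, Set[Node]]) -> bool:
    """Every edge (u, v) satisfies timestamp(u) < timestamp(v)"""
    return all(u[1] < v[1] for v, us in preds.items() for u in us)


transfers = [("A1", "A2", 1), ("A2", "A1", 2)]  # A1 -> A2, then back to A1
versioned = build_versioned_graph(transfers)
print(is_acyclic(versioned), edges_respect_time(versioned))  # True True

well_level = {"A1": {"A2"}, "A2": {"A1"}}  # same transfers, one node per well
print(is_acyclic(well_level))  # False
```

The same back-and-forth transfer is a cycle at the well level. It becomes a chain of forward-in-time edges once each operation writes new versions.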
Why attribute models aren’t enough
Attribute-focused GNNs answer questions like “what volume should be in well A?” given a fixed graph. Workflow synthesis instead requires proposing valid connectivity changes. We need a model that:
- Proposes the next operation (an edge or set of edges), including its endpoints and attributes.
- Respects constraints (conservation of volume, sterility, capacity, tool availability).
- Plans long-horizon sequences to reach targets.
Modeling approaches for connectivity generation
There are several viable families, which can be combined:
1) Autoregressive edge generation
- Factorize p(protocol) into a sequence of edge additions. At each step, encode the current graph with a message-passing GNN; a policy head scores candidate tuples (op type, source node, target node, attributes).
- Sampling: top-k or beam search with constraint masking.
- Benefits: precise control and easy constraint integration. Drawbacks: errors can compound over long horizons.
2) Diffusion or denoising over graphs
- Start from a noisy action plan and denoise into a valid workflow using a GNN denoiser conditioned on task goals and instrument state.
- Useful for exploring diverse plans; requires careful constraint handling during sampling.
3) Constraint-satisfying planning with neural guidance
- Use a symbolic planner or MILP/CP-SAT to enforce hard physical constraints; use a GNN to learn heuristics (cost-to-go, action priors) that guide the search.
- Strong guarantees with improved speed/quality from learning.
4) Imitation + RL hybrid
- Train the policy with behavior cloning on historical protocols; fine-tune with RL using a simulator that implements lab physics and penalties for invalid or unsafe actions.
Let me elaborate on each approach with technical details and examples:
1. Autoregressive Edge Generation
This approach treats protocol generation as a sequence modeling problem where each step adds one or more edges to the growing workflow DAG.
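One standard way to write this decomposition is shown below. The notation is introduced here for illustration. Let $a_t = (o_t, s_t, v_t, x_t)$ be the operation added at step $t$ (op type, source node, target node, continuous attributes). Let $G_{t-1}$ be the graph before that step and $g$ the goal. Let $\ell_i$ be the pointer logit for node $i$, and let $m_t(i) \in \{0, 1\}$ be the feasibility mask. A dot ($\cdot$) abbreviates the conditioning on $(G_{t-1}, g)$.

$$
\begin{aligned}
p_\theta(a_{1:T} \mid g) &= \prod_{t=1}^{T} p_\theta(a_t \mid G_{t-1}, g), \\
p_\theta(a_t \mid G_{t-1}, g) &= p_\theta(o_t \mid \cdot)\, p_\theta(s_t \mid o_t, \cdot)\, p_\theta(v_t \mid o_t, s_t, \cdot)\, p_\theta(x_t \mid o_t, s_t, v_t, \cdot), \\
p_\theta(s_t = i \mid o_t, \cdot) &= \frac{m_t(i)\, \exp \ell_i}{\sum_{j} m_t(j)\, \exp \ell_j}.
\end{aligned}
$$

Multiplying $\exp \ell_i$ by the mask is equivalent to setting the logits of infeasible nodes to $-\infty$ before the softmax. The target pointer is masked the same way.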
Architecture Details:
- Encoder: A k-layer message-passing GNN (e.g., GraphSAGE, GAT) processes the current graph state
- Policy Head: Multi-output network that predicts:
  - Operation type (categorical: transfer, mix, aspirate, dispense, etc.)
  - Source node selection (pointer network over available nodes)
  - Target node selection (pointer network with feasibility masking)
  - Continuous attributes (volume, speed, temperature) with bounded distributions
- Timestamp assignment: each new edge receives a timestamp later than every existing one, which preserves the DAG property
Training Strategy:
- Teacher forcing: use ground truth previous actions during training
- Scheduled sampling: gradually transition from teacher forcing to autoregressive generation
- Constraint masking: zero out probabilities for invalid actions (e.g., transferring from empty wells, creating cycles) by setting their logits to -inf before the softmax
- DAG enforcement: Ensure timestamp(u) < timestamp(v) for all new edges (u,v)
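Masking is easy to get subtly wrong. The sketch below uses plain Python for illustration; a tensor version would use something like PyTorch's `masked_fill` with `-inf`. It shows a masked softmax that gives infeasible actions exactly zero probability. It also shows why multiplying raw logits by a 0/1 mask does not achieve the same thing.

```python
import math
from typing import List


def masked_softmax(logits: List[float], feasible: List[bool]) -> List[float]:
    """Softmax in which infeasible actions get exactly zero probability.

    Infeasible logits are treated as -inf, so they contribute exp(-inf) = 0
    to the normalizer.
    """
    valid = [x for x, ok in zip(logits, feasible) if ok]
    if not valid:
        raise ValueError("no feasible action")
    m = max(valid)  # subtract the max for numerical stability
    exps = [math.exp(x - m) if ok else 0.0 for x, ok in zip(logits, feasible)]
    total = sum(exps)
    return [e / total for e in exps]


logits = [-2.0, 0.5, -1.0]
feasible = [True, False, True]
print(masked_softmax(logits, feasible))  # middle entry is exactly 0.0

# Multiplying logits by the mask instead gives [-2.0, 0.0, -1.0]:
# the infeasible action now has the LARGEST logit and would be picked by argmax
multiplied = [x * float(ok) for x, ok in zip(logits, feasible)]
print(multiplied.index(max(multiplied)))  # 1
```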
Example Implementation:
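```python
# Simplified pseudocode: gnn_encoder, op_classifier, source_pointer,
# target_pointer, feasibility_masks, sample_action and
# current_graph.max_timestamp are placeholders, not a real library API
def generate_step(current_graph, goal_embedding):
    # Encode current state
    node_embeddings = gnn_encoder(current_graph)
    # Predict next operation
    op_type = op_classifier(node_embeddings, goal_embedding)
    # Select source and target with pointer networks
    source_logits = source_pointer(node_embeddings, op_type)
    target_logits = target_pointer(node_embeddings, op_type, source_logits)
    # Apply feasibility masks (boolean tensors derived from the current state).
    # Infeasible entries get -inf so the softmax assigns them zero probability;
    # multiplying logits by a 0/1 mask would NOT exclude them
    source_mask, target_mask = feasibility_masks(current_graph, op_type)
    source_logits = source_logits.masked_fill(~source_mask, float("-inf"))
    target_logits = target_logits.masked_fill(~target_mask, float("-inf"))
    # The new edge must be later than every existing edge (DAG invariant)
    timestamp = current_graph.max_timestamp + 1
    # Sample and return action
    return sample_action(op_type, source_logits, target_logits, timestamp)
```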
Advantages:
- Direct control over generation process
- Easy to integrate hard constraints via masking
- Interpretable: each action is explicit and traceable
- Can use beam search for better planning
Challenges:
- Sequential nature limits parallelization
- Error accumulation over long sequences
- Requires careful curriculum learning for complex protocols
2. Diffusion/Denoising over Graphs
This approach starts from a noisy, potentially invalid workflow and progressively denoises it into a valid protocol DAG.
Architecture Details:
- Noise Schedule: gradually corrupt a clean protocol over K diffusion steps (K is separate from the protocol's own time steps t = 0..T)
- Denoiser: GNN that predicts the clean protocol given the noisy input and the diffusion step k
- Conditioning: Task goals, instrument constraints, and current lab state
- DAG Structure: Denoiser must learn to respect temporal ordering constraints
Training Process:
- Start with clean protocols from the dataset
- Corrupt them over K diffusion steps: Gaussian noise for continuous attributes (volumes, speeds) and categorical noise for discrete choices such as edge existence and operation type
- Train the denoiser to predict the original protocol given the noisy version and the diffusion step k
- Use classifier-free guidance for better control
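The forward (noising) process can be sketched for a single protocol in plain Python. This is a minimal illustration under stated assumptions, not the author's implementation. Volumes are assumed to be pre-normalized to unit scale, and edges are flattened to 0/1 adjacency entries. The schedule is the linear beta schedule popularized by DDPM (1e-4 to 0.02). Continuous attributes are corrupted as sqrt(ᾱ_k)·x0 + sqrt(1 − ᾱ_k)·ε. Discrete entries are kept with probability ᾱ_k and otherwise resampled uniformly, in the style of uniform-transition discrete diffusion.

```python
import math
import random
from typing import List, Tuple


def linear_alpha_bars(num_steps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02) -> List[float]:
    """alpha_bar_k = prod_{j<=k} (1 - beta_j) for a linear beta schedule"""
    alpha_bars, prod = [], 1.0
    for k in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * k / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def corrupt(volumes: List[float], edges: List[int], alpha_bar: float,
            rng: random.Random, num_classes: int = 2) -> Tuple[List[float], List[int]]:
    """Sample a noisy protocol x_k ~ q(x_k | x_0)"""
    # Continuous attributes: Gaussian corruption
    noisy_volumes = [math.sqrt(alpha_bar) * v + math.sqrt(1.0 - alpha_bar) * rng.gauss(0.0, 1.0)
                     for v in volumes]
    # Discrete entries: keep with prob alpha_bar, otherwise resample uniformly
    noisy_edges = [e if rng.random() < alpha_bar else rng.randrange(num_classes)
                   for e in edges]
    return noisy_volumes, noisy_edges


rng = random.Random(0)
alpha_bars = linear_alpha_bars(1000)
clean_volumes = [0.5, 0.25, 0.25]  # normalized transfer volumes (assumed scale)
clean_edges = [1, 0, 1, 1]         # flattened adjacency entries
print(corrupt(clean_volumes, clean_edges, alpha_bars[10], rng))  # mostly intact
print(corrupt(clean_volumes, clean_edges, alpha_bars[-1], rng))  # close to pure noise
```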
Example Implementation:
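```python
# Pseudocode: denoiser, project_to_constraints, denoise_step,
# decode_and_validate and protocol_shape are placeholders
import torch

def diffusion_generate(goal_embedding, num_steps=1000):
    # Start with pure noise over a fixed-size protocol representation
    x = torch.randn(protocol_shape)
    for k in reversed(range(num_steps)):
        # Predict the clean protocol from the noisy one at diffusion step k
        predicted_clean = denoiser(x, k, goal_embedding)
        # Apply constraint projection (e.g., clip volumes, drop backward-in-time edges)
        predicted_clean = project_to_constraints(predicted_clean)
        # One reverse (denoising) step toward the projected prediction
        x = denoise_step(x, predicted_clean, k)
    # Round to a discrete DAG and run the hard constraint checks before use
    return decode_and_validate(x)
```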
Advantages:
- Can generate diverse, high-quality protocols
- Natural handling of global structure
- Good at exploring solution space
Challenges:
- Requires many denoising steps
- Constraint satisfaction during sampling is tricky
- Less interpretable than autoregressive methods
3. Constraint-Satisfying Planning with Neural Guidance
This hybrid approach combines symbolic planning with learned heuristics from GNNs.
Architecture Details:
- Symbolic Planner: MILP/CP-SAT solver that enforces hard constraints including DAG structure
- Neural Heuristic: GNN that learns to guide the search efficiently
- Integration: Use GNN predictions to order search branches or estimate costs
- Temporal Constraints: Solver ensures timestamp ordering and prevents cycles
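The integration point can be illustrated with a best-first (A*-style) search. This is a minimal sketch, and every name in it is hypothetical. A symbolic `successors` function produces only moves that satisfy the hard constraints. The `heuristic` argument is where a learned GNN cost-to-go estimate would plug in. The toy problem (dispense 10 or 25 µL until a well holds exactly 60 µL, without exceeding 100 µL) is an assumption chosen for brevity.

```python
import heapq
from typing import Callable, Dict, Hashable, Iterable, List, Optional, Tuple

State = Hashable


def guided_search(start: State,
                  is_goal: Callable[[State], bool],
                  successors: Callable[[State], Iterable[Tuple[str, State, float]]],
                  heuristic: Callable[[State], float]) -> Optional[List[str]]:
    """Best-first search ordered by cost-so-far + heuristic(state)"""
    counter = 0  # tie-breaker so states and plans are never compared
    frontier = [(heuristic(start), counter, 0.0, start, [])]
    best_cost: Dict[State, float] = {start: 0.0}
    while frontier:
        _, _, cost, state, plan = heapq.heappop(frontier)
        if is_goal(state):
            return plan
        if cost > best_cost.get(state, float("inf")):
            continue  # stale queue entry
        for action, nxt, step_cost in successors(state):
            new_cost = cost + step_cost
            if new_cost < best_cost.get(nxt, float("inf")):
                best_cost[nxt] = new_cost
                counter += 1
                heapq.heappush(frontier, (new_cost + heuristic(nxt), counter,
                                          new_cost, nxt, plan + [action]))
    return None


GOAL_UL, CAPACITY_UL = 60, 100


def successors(volume: int):
    for step in (10, 25):
        if volume + step <= CAPACITY_UL:  # hard capacity constraint (symbolic)
            yield f"dispense {step} uL", volume + step, 1.0


def heuristic(volume: int) -> float:
    # Stand-in for a GNN cost-to-go head: at most 25 uL per action
    return max(GOAL_UL - volume, 0) / 25


print(guided_search(0, lambda v: v == GOAL_UL, successors, heuristic))
```

Because this stand-in heuristic never overestimates the remaining cost, the search returns a shortest plan (three dispenses). A learned heuristic gives no such guarantee. It can still speed up the search, while the hard constraints remain enforced by `successors`.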
Training Strategy:
- Collect planning traces from solver
- Train GNN to predict:
- Action priors (which operations are likely useful)
- Cost-to-go estimates (how expensive remaining steps will be)
- Constraint violation likelihood
Example Implementation:
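```python
# Pseudocode: gnn_encoder, prior_head, cost_head and symbolic_planner are
# placeholders; the planner is assumed to accept per-state heuristic callbacks
def guided_planning(initial_state, goal):
    # Learned heuristics, evaluated on whichever state the search is expanding
    def action_priors(state):
        return prior_head(gnn_encoder(state), goal)  # which operations look useful

    def cost_estimate(state):
        return cost_head(gnn_encoder(state), goal)  # estimated cost-to-go

    # The symbolic planner enforces the hard constraints; the GNN only
    # orders branches and estimates remaining cost
    plan = symbolic_planner(
        initial_state,
        goal,
        action_heuristics=action_priors,
        cost_heuristics=cost_estimate,
    )
    return plan
```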
Advantages:
- Guaranteed satisfaction of every constraint that is modeled explicitly in the solver
- Can leverage decades of optimization research
- Neural guidance improves search efficiency
Challenges:
- Requires symbolic constraint modeling
- Integration complexity
- May be slower than pure neural approaches
4. Imitation + RL Hybrid
This approach starts with supervised learning on historical data and refines with reinforcement learning.
Architecture Details:
- Behavior Cloning: Initial training on expert demonstrations
- RL Fine-tuning: Use simulator rewards to improve policy
- Hybrid Loss: Combine imitation and RL objectives
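One common way to write such a hybrid loss is shown below. The exact form and weighting are design choices, not something the approach prescribes. It adds a behavior-cloning term on expert state-action pairs to the RL objective, weighted by $\lambda \ge 0$:

$$
\mathcal{L}(\theta) = \mathcal{L}_{\mathrm{RL}}(\theta) + \lambda \, \mathbb{E}_{(s, a) \sim \mathcal{D}_{\mathrm{expert}}}\!\left[ -\log \pi_\theta(a \mid s) \right]
$$

Here $\mathcal{L}_{\mathrm{RL}}$ could be, for example, a PPO surrogate loss. Annealing $\lambda$ toward zero during fine-tuning shifts the policy from imitation toward reward optimization, which is one way to manage the imitation-versus-exploration balance.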
Training Phases:
- Phase 1: Train policy to mimic expert protocols
- Phase 2: Use RL to optimize for efficiency, robustness, and safety
- Phase 3: Iterative improvement with human feedback
Example Implementation:
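```python
# Pseudocode: train_imitation, policy.update and the simulator interface
# (reset/step returning next_state, reward, done) are placeholders
def hybrid_training(expert_data, simulator, num_episodes=1000):
    # Phase 1: Behavior cloning
    policy = train_imitation(expert_data)
    # Phase 2: RL fine-tuning
    for episode in range(num_episodes):
        state = simulator.reset()
        trajectory = []
        done = False
        while not done:
            action = policy(state)
            next_state, reward, done = simulator.step(action)
            trajectory.append((state, action, reward, next_state, done))
            state = next_state
        # On-policy algorithms such as PPO update on collected rollouts
        # (often batched over several episodes), not after every single step
        policy.update(trajectory)
    return policy
```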
Advantages:
- Starts with reasonable behavior
- Can optimize for complex objectives
- Combines best of supervised and RL
Challenges:
- Requires high-quality simulator
- RL training can be unstable
- Need to balance imitation vs. exploration
Combining Approaches
These approaches are complementary and can be combined, for example:
- Use autoregressive generation for high-level structure
- Apply diffusion for local refinements
- Use symbolic planning for critical safety constraints
- Fine-tune with RL for efficiency optimization
Action parameterization and constraint masking
To keep the action space tractable:
- Predict operation type first (transfer/mix/thermo step), then endpoints via pointer networks over node embeddings, then continuous attributes (e.g., volume) with bounded distributions.
- Timestamp assignment: Each new operation must have a timestamp greater than all previous operations to maintain DAG structure.
- Apply masks derived from current state: available tips, sufficient volume at source, capacity at destination, deck reachability, sterility compatibility.
- DAG constraint masking: Prevent edges that would create cycles or violate temporal ordering.
- Enforce invariants by projection (e.g., clip volumes to feasible ranges) and by rejecting invalid samples.
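Projection and rejection can be sketched in a few lines of Python. The pipette and well limits below are assumptions for illustration only; substitute the real instrument and labware values. The feasible interval for a transfer is bounded by the pipette range, the liquid available at the source, and the free capacity at the destination. An empty interval means the (source, target) pair should be masked or the sample rejected.

```python
from typing import Optional, Tuple

# Assumed limits for illustration only
MIN_PIPETTE_UL = 1.0
MAX_PIPETTE_UL = 50.0
WELL_CAPACITY_UL = 300.0


def feasible_volume_range(source_volume: float,
                          target_volume: float) -> Optional[Tuple[float, float]]:
    """Volumes that neither overdraw the source nor overflow the target"""
    upper = min(MAX_PIPETTE_UL, source_volume, WELL_CAPACITY_UL - target_volume)
    if upper < MIN_PIPETTE_UL:
        return None  # no feasible volume: mask this (source, target) pair
    return MIN_PIPETTE_UL, upper


def project_volume(proposed: float, source_volume: float,
                   target_volume: float) -> Optional[float]:
    """Clip a predicted volume onto the feasible interval, or reject the action"""
    bounds = feasible_volume_range(source_volume, target_volume)
    if bounds is None:
        return None
    low, high = bounds
    return min(max(proposed, low), high)


print(project_volume(80.0, source_volume=30.0, target_volume=100.0))   # 30.0 (source-limited)
print(project_volume(20.0, source_volume=200.0, target_volume=290.0))  # 10.0 (capacity-limited)
print(project_volume(5.0, source_volume=0.5, target_volume=0.0))       # None (rejected)
```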
State representation details
- Node features: current volume, composition embedding (e.g., learned from reagent ontology), temperature, contamination flags, container geometry.
- Edge features: operation type, executed volume, time since last action, tip id.
- Global features: assay goal embedding, allowed instruments, remaining time budget.
- Temporal encoding: append step index or use recurrent GNN layers to retain history.
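The feature vectors above can be assembled as in the following minimal sketch. Several parts are assumptions: the reagent vocabulary, the normalization constants, and the use of a fixed-order composition vector in place of a learned ontology embedding. The step index is encoded with Transformer-style sinusoids, which is one option for "append step index".

```python
import math
from typing import Dict, List

REAGENTS = ["stock", "diluent"]  # hypothetical fixed reagent vocabulary
WELL_CAPACITY_UL = 300.0         # assumed well capacity used for normalization


def step_encoding(step: int, dim: int = 4, base: float = 10000.0) -> List[float]:
    """Sinusoidal (Transformer-style) encoding of the protocol step index"""
    enc: List[float] = []
    for i in range(dim // 2):
        freq = 1.0 / base ** (2 * i / dim)
        enc += [math.sin(step * freq), math.cos(step * freq)]
    return enc


def node_features(volume_ul: float, composition: Dict[str, float],
                  temperature_c: float, contaminated: bool) -> List[float]:
    """Normalized scalars followed by a fixed-order composition vector
    (a simple stand-in for a learned composition embedding)"""
    comp = [composition.get(r, 0.0) for r in REAGENTS]
    return [volume_ul / WELL_CAPACITY_UL, temperature_c / 100.0, float(contaminated)] + comp


def global_features(goal_embedding: List[float], remaining_s: float,
                    budget_s: float, step: int) -> List[float]:
    """Goal embedding, fraction of the time budget left, and step encoding"""
    return goal_embedding + [remaining_s / budget_s] + step_encoding(step)


x = node_features(150.0, {"stock": 0.5, "diluent": 0.5}, 25.0, contaminated=False)
g = global_features([0.1, 0.0], remaining_s=1800.0, budget_s=3600.0, step=3)
print(len(x), len(g))  # 5 7
```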
Training signals and datasets
- Imitation data: parse existing protocols (e.g., from OT-2, Hamilton scripts) into action graphs.
- Supervision: next-edge classification, endpoint selection, and attribute regression; auxiliary losses for state prediction (e.g., next-node volume) improve stability.
- Negative sampling: generate near-miss actions (slightly over volume, wrong tip) to sharpen constraint awareness.
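Near-miss negatives can be generated by pushing a valid action just past one constraint boundary. In this minimal sketch, the `Transfer` record, the validity check, and the 300 µL capacity are illustrative assumptions. The filter keeps only perturbations that genuinely fail the check.

```python
import random
from dataclasses import dataclass, replace
from typing import Dict, List, Set

WELL_CAPACITY_UL = 300.0  # assumed well capacity


@dataclass(frozen=True)
class Transfer:
    source: str
    target: str
    volume: float
    tip_id: str


def is_valid(t: Transfer, volumes: Dict[str, float], used_tips: Set[str]) -> bool:
    """Hypothetical check: enough liquid, no overflow, and a fresh tip"""
    return (volumes.get(t.source, 0.0) >= t.volume
            and volumes.get(t.target, 0.0) + t.volume <= WELL_CAPACITY_UL
            and t.tip_id not in used_tips)


def near_misses(positive: Transfer, volumes: Dict[str, float], used_tips: Set[str],
                rng: random.Random, n: int = 4) -> List[Transfer]:
    """Perturb a valid action just past a constraint boundary"""
    over_source = volumes[positive.source] + rng.uniform(0.5, 5.0)
    over_target = WELL_CAPACITY_UL - volumes.get(positive.target, 0.0) + rng.uniform(0.5, 5.0)
    candidates = [
        replace(positive, volume=over_source),  # slightly more than the source holds
        replace(positive, volume=over_target),  # slightly more than the target can take
    ] + [replace(positive, tip_id=tip) for tip in sorted(used_tips)]  # reused (wrong) tip
    negatives = [c for c in candidates if not is_valid(c, volumes, used_tips)]
    rng.shuffle(negatives)
    return negatives[:n]


volumes = {"A1": 200.0, "A2": 100.0}
used = {"tip_0"}
positive = Transfer("A1", "A2", 20.0, "tip_1")
for negative in near_misses(positive, volumes, used, random.Random(0)):
    print(negative)
```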
Evaluation metrics
- Validity: fraction of generated steps passing all constraints; zero spills/overflows; no cross-contamination.
- Goal satisfaction: assay success rate, target composition accuracy.
- Efficiency: action count, total time, tip consumption, deck moves.
- Diversity: unique valid workflows per goal.
- Sim-to-real: execution success on hardware with minimal edits.
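Several of these metrics can be computed directly from simulator replays. This minimal sketch assumes a hypothetical `replay` hook that returns one pass/fail flag per step. A toy replay with a 200 µL stock in A1 and 300 µL wells stands in for a real simulator. Efficiency and diversity are computed only over workflows whose steps all pass.

```python
from typing import Callable, Dict, List, Tuple

Step = Tuple[str, str, float, str]  # (source, target, volume_ul, tip_id), hypothetical
Workflow = Tuple[Step, ...]


def _mean(xs: List[float]) -> float:
    return sum(xs) / len(xs) if xs else 0.0


def evaluate(workflows_by_goal: Dict[str, List[Workflow]],
             replay: Callable[[Workflow], List[bool]]) -> Dict[str, float]:
    """Validity, efficiency, and diversity metrics for generated workflows"""
    step_flags: List[float] = []
    actions: List[float] = []
    tips: List[float] = []
    unique_valid: List[float] = []
    for workflows in workflows_by_goal.values():
        valid = set()
        for wf in workflows:
            flags = replay(wf)
            step_flags += [float(f) for f in flags]
            if all(flags):
                valid.add(wf)
                actions.append(len(wf))               # action count
                tips.append(len({s[3] for s in wf}))  # tip consumption
        unique_valid.append(len(valid))
    return {
        "step_validity": _mean(step_flags),
        "mean_actions": _mean(actions),  # over fully valid workflows only
        "mean_tips": _mean(tips),
        "unique_valid_per_goal": _mean(unique_valid),
    }


def toy_replay(wf: Workflow) -> List[bool]:
    """Toy simulator: A1 starts with 200 uL; wells hold at most 300 uL"""
    volumes = {"A1": 200.0}
    flags = []
    for source, target, vol, _tip in wf:
        ok = volumes.get(source, 0.0) >= vol and volumes.get(target, 0.0) + vol <= 300.0
        if ok:
            volumes[source] -= vol
            volumes[target] = volumes.get(target, 0.0) + vol
        flags.append(ok)
    return flags


good = (("A1", "A2", 50.0, "tip_0"), ("A2", "A3", 25.0, "tip_1"))
bad = (("A2", "A3", 25.0, "tip_0"),)  # A2 is still empty
print(evaluate({"serial_dilution": [good, good, bad]}, toy_replay))
```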
Minimal prototype sketch
Outline of an autoregressive generator with constraint masking:
- Encode current graph with a k-layer message-passing GNN.
- Predict operation type with a masked classifier.
- Select source and target nodes using pointer heads over node embeddings with feasibility masks.
- Assign timestamp: Ensure new operation timestamp > all previous timestamps to maintain DAG.
- Regress attributes (volume, speed) with bounded outputs; project to valid ranges.
- Validate DAG: Check that no cycles would be created by the new edge.
- Update node states and append the new edge; repeat until done.
- Use beam search for better plans; score beams by learned value function + hard constraint checks.
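The beam-search step can be sketched generically. All callables here are hypothetical stand-ins: `expand` plays the policy (actions with log-probabilities), `is_valid` is the hard constraint check, and `value` is the learned value function. Each beam is scored by its summed log-probability plus the value of its current state. Invalid actions never enter a beam. The toy problem (fill a well to 60 µL with 10 or 25 µL dispenses, capacity 100 µL) is an assumption.

```python
import heapq
import math
from typing import Callable, Hashable, Iterable, List, Tuple

State = Hashable
Action = Hashable


def beam_search(initial: State,
                expand: Callable[[State], Iterable[Tuple[Action, float]]],
                is_valid: Callable[[State, Action], bool],
                apply: Callable[[State, Action], State],
                value: Callable[[State], float],
                is_done: Callable[[State], bool],
                beam_width: int = 3,
                max_steps: int = 10) -> List[Action]:
    """Keep the best partial plans by (sum of log-probs + value of state)"""
    beams = [(0.0, initial, [])]  # (cumulative log-prob, state, plan)

    def score(beam) -> float:
        return beam[0] + value(beam[1])

    for _ in range(max_steps):
        candidates = []
        for logp, state, plan in beams:
            if is_done(state):
                candidates.append((logp, state, plan))  # finished plans carry over
                continue
            for action, action_logp in expand(state):
                if is_valid(state, action):  # hard constraint check before scoring
                    candidates.append((logp + action_logp, apply(state, action), plan + [action]))
        if not candidates:
            break
        beams = heapq.nlargest(beam_width, candidates, key=score)
        if all(is_done(state) for _, state, _ in beams):
            break
    return max(beams, key=score)[2]


GOAL_UL, CAPACITY_UL = 60, 100
plan = beam_search(
    initial=0,
    expand=lambda v: [(10, math.log(0.6)), (25, math.log(0.4))],  # stand-in policy
    is_valid=lambda v, a: v + a <= CAPACITY_UL,                      # hard capacity check
    apply=lambda v, a: v + a,
    value=lambda v: -abs(GOAL_UL - v) / 10,                          # stand-in value head
    is_done=lambda v: v == GOAL_UL,
)
print(plan)  # three dispenses totaling 60 uL
```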
Concrete Example: Variable Serial Dilution Network Discovery
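Let’s implement a simplified version of the autoregressive approach for discovering the network required for a variable serial dilution on a 96-well plate. The networks are untrained (randomly initialized), so the example illustrates the mechanics of constraint masking, volume projection, DAG enforcement, and state updates rather than producing a good dilution plan.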
```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional

# Assumed limits for illustration; adjust to the actual plate and pipette
WELL_CAPACITY_UL = 300.0
MAX_TRANSFER_UL = 50.0
MIN_TRANSFER_UL = 1.0


# Define operation types
class OpType(Enum):
    ASPIRATE = "aspirate"
    DISPENSE = "dispense"
    TRANSFER = "transfer"
    MIX = "mix"


@dataclass
class LiquidState:
    """Represents the state of liquid in a well"""
    volume: float              # Current volume in μL
    concentration: float       # Concentration of target compound (ng/μL)
    contamination_risk: float  # Risk of cross-contamination (0-1)
    timestamp: int             # When this state was last written


@dataclass
class Operation:
    """Represents a liquid handling operation"""
    op_type: OpType
    source_well: Optional[str]  # None = draw diluent from the reservoir
    target_well: str
    volume: float
    timestamp: int
    tip_id: str


class DilutionWorkflow:
    """Represents the current state of a dilution workflow"""

    def __init__(self, plate_rows: int = 8, plate_cols: int = 12):
        self.plate_rows = plate_rows
        self.plate_cols = plate_cols
        self.all_well_ids = [self.get_well_id(r, c)
                             for r in range(plate_rows) for c in range(plate_cols)]
        self.wells: Dict[str, LiquidState] = {}  # well_id -> LiquidState
        self.operations: List[Operation] = []
        self.available_tips = [f"tip_{i}" for i in range(8)]  # 8-channel pipette
        self.timestamp = 0
        # Initialize source wells (e.g., A1 has stock solution)
        self.wells["A1"] = LiquidState(volume=200.0, concentration=1000.0,
                                       contamination_risk=0.0, timestamp=0)

    def get_well_id(self, row: int, col: int) -> str:
        """Convert row/col to well ID (e.g., A1, B2)"""
        return f"{chr(65 + row)}{col + 1}"

    def volume_of(self, well_id: Optional[str]) -> float:
        """Current volume of a well; wells never written are empty"""
        state = self.wells.get(well_id)
        return state.volume if state else 0.0

    def can_transfer(self, source: Optional[str], target: str,
                     volume: float, timestamp: int) -> bool:
        """Check if a transfer (source=None: from the diluent reservoir) is valid"""
        if target not in self.all_well_ids or source == target or volume <= 0:
            return False
        # Capacity at the destination (empty wells are valid targets)
        if self.volume_of(target) + volume > WELL_CAPACITY_UL:
            return False
        # Check contamination risk (can't transfer to contaminated wells)
        if target in self.wells and self.wells[target].contamination_risk > 0.5:
            return False
        read_states = [self.wells[target]] if target in self.wells else []
        if source is not None:
            # Source must exist and hold enough liquid
            if source not in self.wells or self.wells[source].volume < volume:
                return False
            read_states.append(self.wells[source])
        # DAG constraint: the operation must come strictly after every state it
        # reads, so each new edge points forward in time and no cycle can form
        return all(timestamp > s.timestamp for s in read_states)

    def add_operation(self, op: Operation):
        """Add an operation and update well states"""
        self.operations.append(op)
        self.timestamp = max(self.timestamp, op.timestamp)
        if op.op_type not in (OpType.TRANSFER, OpType.DISPENSE):
            return  # this simplified model tracks no tip contents for aspirate/mix
        if op.source_well is not None:
            # Update source well: it loses volume; its concentration is unchanged
            source_state = self.wells[op.source_well]
            source_state.volume -= op.volume
            source_state.timestamp = op.timestamp
            source_conc = source_state.concentration
            source_risk = source_state.contamination_risk
        else:
            source_conc, source_risk = 0.0, 0.0  # clean diluent
        # Update target well (created on first use)
        target_state = self.wells.setdefault(
            op.target_well,
            LiquidState(volume=0.0, concentration=0.0,
                        contamination_risk=0.0, timestamp=op.timestamp))
        new_volume = target_state.volume + op.volume
        # New concentration is the volume-weighted average of both liquids
        target_state.concentration = (
            target_state.volume * target_state.concentration
            + op.volume * source_conc) / new_volume
        target_state.volume = new_volume
        target_state.timestamp = op.timestamp
        # Contamination risk propagates along the transfer
        target_state.contamination_risk = max(target_state.contamination_risk, source_risk)


def mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
    """Two-layer MLP used by every prediction head"""
    return nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                         nn.Linear(hidden_dim, out_dim))


class DilutionNetworkGenerator(nn.Module):
    """Generates dilution networks using a simplified GNN-like approach.

    Weights are randomly initialized (untrained), so choices are arbitrary,
    but every choice passes through the constraint masks.
    """

    def __init__(self, hidden_dim: int = 64, node_dim: int = 4, global_dim: int = 3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Learned encoders for per-well and global features
        self.node_encoder = nn.Linear(node_dim, hidden_dim)
        self.global_encoder = nn.Linear(global_dim, hidden_dim)
        # Heads take [node embedding, global embedding] or its pooled version
        self.op_type_predictor = mlp(hidden_dim * 2, hidden_dim, len(OpType))
        self.source_predictor = mlp(hidden_dim * 2, hidden_dim, 1)
        self.target_predictor = mlp(hidden_dim * 2, hidden_dim, 1)
        self.volume_predictor = nn.Sequential(mlp(hidden_dim * 2, hidden_dim, 1),
                                              nn.Sigmoid())  # 0-1, scaled to μL

    def encode_workflow_state(self, workflow: DilutionWorkflow) -> Dict[str, object]:
        """Encode every plate well (empty wells as zeros) plus global context"""
        well_features = []
        for well_id in workflow.all_well_ids:
            state = workflow.wells.get(well_id)
            if state is None:
                well_features.append([0.0, 0.0, 0.0, 0.0])
            else:
                well_features.append([
                    state.volume / WELL_CAPACITY_UL,  # Normalize volume
                    state.concentration / 1000.0,     # Normalize concentration
                    state.contamination_risk,
                    state.timestamp / 100.0,          # Normalize timestamp
                ])
        # Global context: goal concentration, remaining wells to fill, clock
        max_wells = len(workflow.all_well_ids)
        target_concentration = 100.0  # Example target
        remaining_wells = max_wells - sum(1 for w in workflow.wells.values() if w.volume > 0)
        global_features = [target_concentration / 1000.0,
                           remaining_wells / max_wells,
                           workflow.timestamp / 100.0]
        return {
            'well_features': torch.tensor(well_features, dtype=torch.float32),
            'well_ids': list(workflow.all_well_ids),
            'global_features': torch.tensor(global_features, dtype=torch.float32),
        }

    @torch.no_grad()
    def predict_next_operation(self, workflow: DilutionWorkflow) -> Optional[Operation]:
        """Predict the next operation; returns None if no feasible action remains"""
        encoded = self.encode_workflow_state(workflow)
        well_ids = encoded['well_ids']
        # Per-well embeddings plus a global embedding (no message passing here)
        node_emb = torch.relu(self.node_encoder(encoded['well_features']))        # (N, H)
        global_emb = torch.relu(self.global_encoder(encoded['global_features']))  # (H,)
        combined = torch.cat([node_emb, global_emb.expand(node_emb.shape[0], -1)], dim=1)
        pooled = torch.cat([node_emb.mean(dim=0), global_emb])                     # (2H,)
        neg_inf = float('-inf')

        # Predict operation type; only TRANSFER (well -> well) and DISPENSE
        # (reservoir diluent -> well) change state in this simplified model
        op_mask = torch.tensor([op in (OpType.TRANSFER, OpType.DISPENSE) for op in OpType])
        op_logits = self.op_type_predictor(pooled).masked_fill(~op_mask, neg_inf)
        op_type = list(OpType)[op_logits.argmax().item()]

        # Predict source well (masked: only wells with enough liquid can be sources)
        source_well = None
        if op_type == OpType.TRANSFER:
            source_mask = torch.tensor([workflow.volume_of(w) >= MIN_TRANSFER_UL
                                        for w in well_ids])
            if not source_mask.any():
                return None
            source_scores = self.source_predictor(combined).squeeze(-1)
            source_well = well_ids[source_scores.masked_fill(~source_mask, neg_inf).argmax().item()]

        # Predict target well (masked: spare capacity, and not the source itself)
        target_mask = torch.tensor([
            w != source_well and workflow.volume_of(w) <= WELL_CAPACITY_UL - MIN_TRANSFER_UL
            for w in well_ids])
        if not target_mask.any():
            return None
        target_scores = self.target_predictor(combined).squeeze(-1)
        target_well = well_ids[target_scores.masked_fill(~target_mask, neg_inf).argmax().item()]

        # Predict volume, then project it onto the feasible range
        volume = self.volume_predictor(pooled).item() * MAX_TRANSFER_UL
        upper = min(MAX_TRANSFER_UL, WELL_CAPACITY_UL - workflow.volume_of(target_well))
        if source_well is not None:
            upper = min(upper, workflow.volume_of(source_well))
        volume = max(MIN_TRANSFER_UL, min(volume, upper))

        return Operation(
            op_type=op_type,
            source_well=source_well,
            target_well=target_well,
            volume=volume,
            # Ensure DAG constraint: timestamp must be greater than all previous
            timestamp=workflow.timestamp + 1,
            tip_id=workflow.available_tips[0],  # Simplified: no tip tracking
        )


def goals_met(workflow: DilutionWorkflow, target_concentrations: List[float],
              rel_tol: float = 0.1) -> bool:
    """Every target concentration is matched (within rel_tol) by some filled well"""
    achieved = [w.concentration for w in workflow.wells.values() if w.volume > 0]
    return all(any(abs(c - t) <= rel_tol * t for c in achieved)
               for t in target_concentrations)


def generate_dilution_workflow(target_concentrations: List[float],
                               max_operations: int = 50,
                               max_attempts: int = 200,
                               seed: int = 0) -> DilutionWorkflow:
    """Generate a complete dilution workflow"""
    torch.manual_seed(seed)  # reproducible (untrained) weights
    workflow = DilutionWorkflow()
    generator = DilutionNetworkGenerator().eval()
    # Bound total attempts, not just accepted operations, so repeated
    # invalid proposals cannot loop forever
    for _ in range(max_attempts):
        if len(workflow.operations) >= max_operations or goals_met(workflow, target_concentrations):
            break
        next_op = generator.predict_next_operation(workflow)
        if next_op is None:
            print("No valid actions remain")
            break
        source = next_op.source_well or "reservoir"
        desc = (f"{next_op.op_type.value} {next_op.volume:.1f}μL "
                f"from {source} to {next_op.target_well}")
        # Validate operation before committing it
        if workflow.can_transfer(next_op.source_well, next_op.target_well,
                                 next_op.volume, next_op.timestamp):
            workflow.add_operation(next_op)
            print(f"Added operation: {desc}")
        else:
            print(f"Invalid operation: {desc}")
    return workflow


# Example usage
if __name__ == "__main__":
    # Generate a workflow for 8 different concentrations
    target_concentrations = [800, 600, 400, 200, 100, 50, 25, 12.5]
    print("Generating dilution workflow...")
    workflow = generate_dilution_workflow(target_concentrations)
    filled = {w: s for w, s in workflow.wells.items() if s.volume > 0}
    print(f"\nGenerated {len(workflow.operations)} operations")
    print(f"Final workflow has {len(filled)} wells with liquid")
    # Show final concentrations
    print("\nFinal well states:")
    for well_id, state in sorted(filled.items()):
        print(f"{well_id}: {state.volume:.1f}μL, {state.concentration:.1f} ng/μL")
    # Verify DAG property: operation timestamps strictly increase
    timestamps = [op.timestamp for op in workflow.operations]
    if all(a < b for a, b in zip(timestamps, timestamps[1:])):
        print("\n✓ DAG constraint satisfied: all operations are strictly temporally ordered")
    else:
        print("\n✗ DAG constraint violated: operations are not temporally ordered")
```
This example demonstrates:
- DAG Enforcement: Each operation gets a timestamp greater than all previous operations
- Constraint Masking: infeasible sources and targets are masked out before selection (e.g., source wells must hold liquid)
- State Updates: Well volumes and concentrations are updated after each operation
- Validation: Operations are checked for feasibility before execution
- Goal-Oriented Generation: generation stops once the target concentrations are reached or the operation budget is exhausted
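The generator uses a simplified “GNN-like” approach. Each well is embedded independently and combined with global context, with no message passing between wells. It relies on: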
- Node embeddings based on well features (volume, concentration, contamination, timestamp)
- Global context (target concentration, remaining wells, current timestamp)
- Masked prediction for source/target selection
- Constraint validation to maintain physical and temporal consistency
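What the simplified generator lacks is neighbor exchange. A minimal plain-Python sketch of one mean-aggregation message-passing round over transfer edges is shown below; the weights and two-dimensional features are illustrative assumptions. The update is h_v' = ReLU(W_self h_v + W_nbr · mean of h_u over incoming edges u -> v). Stacking k such layers lets each well's embedding reflect wells up to k hops upstream.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def message_passing_layer(h: Dict[str, Vector], edges: List[Tuple[str, str]],
                          W_self: Matrix, W_nbr: Matrix) -> Dict[str, Vector]:
    """One round of mean aggregation over incoming (u -> v) edges"""
    incoming: Dict[str, List[Vector]] = {v: [] for v in h}
    for u, v in edges:
        incoming[v].append(h[u])
    out: Dict[str, Vector] = {}
    for v, hv in h.items():
        msgs = incoming[v]
        # Permutation-invariant aggregation: element-wise mean of neighbor states
        agg = [sum(col) / len(msgs) for col in zip(*msgs)] if msgs else [0.0] * len(hv)
        pre = [a + b for a, b in zip(matvec(W_self, hv), matvec(W_nbr, agg))]
        out[v] = [max(0.0, x) for x in pre]  # ReLU update
    return out


# Features: [volume fraction, concentration fraction]; edges follow transfers
h = {"A1": [0.6, 1.0], "A2": [0.3, 0.5], "A3": [0.3, 0.0]}
edges = [("A1", "A2"), ("A2", "A3")]
W_self = [[1.0, 0.0], [0.0, 1.0]]
W_nbr = [[0.5, 0.0], [0.0, 0.5]]
h1 = message_passing_layer(h, edges, W_self, W_nbr)
h2 = message_passing_layer(h1, edges, W_self, W_nbr)  # A3 now reflects A1 (two hops)
print(h1["A3"], h2["A3"])
```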
Why GNNs fit this problem
Message passing aligns with local physical constraints while still capturing long-range goals through multiple hops and global features, as the Distill overview (Sanchez-Lengeling et al., 2021) explains. The core difference here is that we use the GNN not to label a fixed graph but to drive the creation of new connectivity under constraints.
Outlook
Bringing workflow generation to practice requires: a realistic simulator with rich constraints, curated protocol datasets, and careful interfaces to planners and robots. The architectural pieces above provide a path to move from classification to connectivity.
References:
- Sanchez-Lengeling, B., Reif, E., Pearce, A., Wiltschko, A. “A Gentle Introduction to Graph Neural Networks,” Distill (2021).