Graph Examples¶
This section provides practical examples of using AugChem for graph-based molecular data augmentation.
Example 1: Basic Graph Augmentation¶
import torch
from torch_geometric.data import Data
from augchem.modules.graph.graphs_modules import augment_dataset

# Create sample molecular graphs
def create_sample_graph(num_nodes, num_edges):
    x = torch.randn(num_nodes, 9)  # 9 node features
    # Random connectivity (may include duplicates or self-loops); fine for a synthetic demo
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_attr = torch.randn(num_edges, 4)  # 4 edge features
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Create a small dataset
graphs = [
    create_sample_graph(10, 18),
    create_sample_graph(12, 22),
    create_sample_graph(8, 14),
    create_sample_graph(15, 28)
]

# Apply augmentation
augmented_graphs = augment_dataset(
    graphs=graphs,
    augmentation_methods=['edge_drop', 'node_drop', 'feature_mask'],
    edge_drop_rate=0.1,
    node_drop_rate=0.05,
    feature_mask_rate=0.15,
    augment_percentage=0.5,  # 50% more data
    seed=42
)

print(f"Dataset expanded from {len(graphs)} to {len(augmented_graphs)} graphs")
Example 2: Processing Real Molecular Data¶
from rdkit import Chem
import torch
from torch_geometric.data import Data
from augchem.modules.graph.graphs_modules import augment_dataset

def smiles_to_graph(smiles):
    """Convert SMILES to PyTorch Geometric graph"""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Node features
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetNumRadicalElectrons(),
            atom.GetTotalNumHs(),
            int(atom.IsInRing()),
            atom.GetMass()
        ]
        atom_features.append(features)

    # Edge features
    edge_indices = []
    edge_attrs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_indices.extend([[i, j], [j, i]])  # Bidirectional
        bond_features = [
            bond.GetBondTypeAsDouble(),
            int(bond.GetIsAromatic()),
            int(bond.IsInRing()),
            int(bond.GetIsConjugated())
        ]
        edge_attrs.extend([bond_features, bond_features])

    return Data(
        x=torch.tensor(atom_features, dtype=torch.float),
        edge_index=torch.tensor(edge_indices, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attrs, dtype=torch.float)
    )

# Example molecules
molecules = [
    "CCO",        # Ethanol
    "CC(=O)O",    # Acetic acid
    "c1ccccc1",   # Benzene
    "CCN(CC)CC",  # Triethylamine
    "CC(C)O"      # Isopropanol
]

# Convert to graphs
graphs = [smiles_to_graph(smiles) for smiles in molecules]
graphs = [g for g in graphs if g is not None]

# Augment the dataset
augmented = augment_dataset(
    graphs=graphs,
    augmentation_methods=['edge_drop', 'feature_mask', 'edge_perturb'],
    augment_percentage=1.0,  # Double the dataset
    seed=42
)

print(f"Augmented {len(molecules)} molecules to {len(augmented)} graphs")
Example 3: Integration with Machine Learning¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

# Simple GNN for molecular property prediction
class MolecularGNN(nn.Module):
    def __init__(self, num_node_features, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv3(x, edge_index))
        # Global pooling
        x = global_mean_pool(x, batch)
        return self.classifier(x)

# Prepare data with augmentation
original_graphs = graphs  # From previous example
augmented_graphs = augment_dataset(
    original_graphs,
    augmentation_methods=['edge_drop', 'node_drop', 'feature_mask'],
    augment_percentage=0.3
)

# Add dummy targets for demonstration
for graph in augmented_graphs:
    graph.y = torch.randn(1)  # Random property value

# Split data
train_graphs, test_graphs = train_test_split(
    augmented_graphs, test_size=0.2, random_state=42
)

# Create data loaders
train_loader = DataLoader(train_graphs, batch_size=16, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=16, shuffle=False)

# Initialize model
model = MolecularGNN(num_node_features=9)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Training loop
model.train()
for epoch in range(10):
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = criterion(out, batch.y.view(-1, 1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1:2d}, Loss: {total_loss/len(train_loader):.4f}")

print("Training completed!")
Example 4: Comparative Analysis¶
import matplotlib.pyplot as plt
import numpy as np

def analyze_augmentation_impact(original_graphs, augmented_graphs):
    """Analyze the impact of augmentation on dataset statistics"""

    # Extract statistics
    def get_stats(graphs):
        nodes = [g.num_nodes for g in graphs]
        edges = [g.edge_index.size(1) for g in graphs]
        return nodes, edges

    orig_nodes, orig_edges = get_stats(original_graphs)
    aug_nodes, aug_edges = get_stats(augmented_graphs)

    # Create comparison plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Node distribution
    axes[0, 0].hist(orig_nodes, bins=20, alpha=0.7, label='Original', color='blue')
    axes[0, 0].hist(aug_nodes, bins=20, alpha=0.7, label='Augmented', color='red')
    axes[0, 0].set_xlabel('Number of Nodes')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('Node Count Distribution')
    axes[0, 0].legend()

    # Edge distribution
    axes[0, 1].hist(orig_edges, bins=20, alpha=0.7, label='Original', color='blue')
    axes[0, 1].hist(aug_edges, bins=20, alpha=0.7, label='Augmented', color='red')
    axes[0, 1].set_xlabel('Number of Edges')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Edge Count Distribution')
    axes[0, 1].legend()

    # Node vs Edge scatter
    axes[1, 0].scatter(orig_nodes, orig_edges, alpha=0.6, label='Original', color='blue')
    axes[1, 0].scatter(aug_nodes, aug_edges, alpha=0.6, label='Augmented', color='red')
    axes[1, 0].set_xlabel('Number of Nodes')
    axes[1, 0].set_ylabel('Number of Edges')
    axes[1, 0].set_title('Nodes vs Edges Relationship')
    axes[1, 0].legend()

    # Statistics summary
    stats_text = f"""Dataset Statistics:

Original Dataset:
• Graphs: {len(original_graphs)}
• Avg nodes: {np.mean(orig_nodes):.1f} ± {np.std(orig_nodes):.1f}
• Avg edges: {np.mean(orig_edges):.1f} ± {np.std(orig_edges):.1f}

Augmented Dataset:
• Graphs: {len(augmented_graphs)}
• Avg nodes: {np.mean(aug_nodes):.1f} ± {np.std(aug_nodes):.1f}
• Avg edges: {np.mean(aug_edges):.1f} ± {np.std(aug_edges):.1f}

Augmentation Impact:
• Size increase: {((len(augmented_graphs)/len(original_graphs))-1)*100:.1f}%
• Node diversity: {(np.std(aug_nodes)/np.std(orig_nodes)-1)*100:+.1f}%
• Edge diversity: {(np.std(aug_edges)/np.std(orig_edges)-1)*100:+.1f}%"""

    axes[1, 1].text(0.05, 0.95, stats_text, transform=axes[1, 1].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
    axes[1, 1].set_xlim(0, 1)
    axes[1, 1].set_ylim(0, 1)
    axes[1, 1].axis('off')
    axes[1, 1].set_title('Dataset Comparison')

    plt.tight_layout()
    plt.show()

# Run analysis
analyze_augmentation_impact(graphs, augmented_graphs)
Example 5: Advanced Augmentation Strategy¶
import numpy as np

from augchem.modules.graph.graphs_modules import (
    edge_dropping, node_dropping, feature_masking, edge_perturbation
)

def create_diverse_augmentations(graph, num_variants=5):
    """Create diverse augmentations of a single graph"""
    variants = []

    for i in range(num_variants):
        # Randomly select augmentation parameters
        edge_drop_rate = np.random.uniform(0.05, 0.15)
        node_drop_rate = np.random.uniform(0.02, 0.08)
        mask_rate = np.random.uniform(0.10, 0.20)

        # Apply different combinations
        if i % 3 == 0:
            # Edge-focused augmentation
            variant = edge_dropping(graph, drop_rate=edge_drop_rate)
            variant = feature_masking(variant, mask_rate=mask_rate)
            variant.augmentation_type = "edge_focused"
        elif i % 3 == 1:
            # Node-focused augmentation
            variant = node_dropping(graph, drop_rate=node_drop_rate)
            variant = feature_masking(variant, mask_rate=mask_rate)
            variant.augmentation_type = "node_focused"
        else:
            # Perturbation-focused augmentation
            variant = edge_perturbation(graph, add_rate=0.03, remove_rate=0.05)
            variant = feature_masking(variant, mask_rate=mask_rate)
            variant.augmentation_type = "perturbation_focused"

        variants.append(variant)

    return variants

# Apply diverse augmentation to each graph
diverse_augmented = []
for i, graph in enumerate(graphs):
    variants = create_diverse_augmentations(graph, num_variants=3)
    diverse_augmented.extend(variants)
    print(f"Graph {i}: created {len(variants)} variants")

print(f"Total augmented graphs: {len(diverse_augmented)}")

# Analyze augmentation types
augmentation_types = [g.augmentation_type for g in diverse_augmented]
type_counts = {t: augmentation_types.count(t) for t in set(augmentation_types)}
print("Augmentation type distribution:", type_counts)
Example 6: Individual Augmentation Techniques¶
from augchem.modules.graph.graphs_modules import (
    edge_dropping, node_dropping, feature_masking, edge_perturbation
)

# Use a sample graph for demonstration
sample_graph = graphs[0]  # Ethanol graph from previous example

print("Original graph statistics:")
print(f"  Nodes: {sample_graph.num_nodes}")
print(f"  Edges: {sample_graph.edge_index.size(1)}")
print(f"  Node features shape: {sample_graph.x.shape}")
print(f"  Edge features shape: {sample_graph.edge_attr.shape}")
print()

# 1. Edge Dropping
print("1. Edge Dropping:")
for drop_rate in [0.1, 0.2, 0.3]:
    edge_dropped = edge_dropping(sample_graph, drop_rate=drop_rate)
    edges_removed = sample_graph.edge_index.size(1) - edge_dropped.edge_index.size(1)
    print(f"  Drop rate {drop_rate:.1f}: {sample_graph.edge_index.size(1)} -> {edge_dropped.edge_index.size(1)} edges ({edges_removed} removed)")
print()

# 2. Node Dropping
print("2. Node Dropping:")
for drop_rate in [0.05, 0.1, 0.15]:
    node_dropped = node_dropping(sample_graph, drop_rate=drop_rate)
    nodes_removed = sample_graph.num_nodes - node_dropped.num_nodes
    print(f"  Drop rate {drop_rate:.2f}: {sample_graph.num_nodes} -> {node_dropped.num_nodes} nodes ({nodes_removed} removed)")
print()

# 3. Feature Masking
print("3. Feature Masking:")
for mask_rate in [0.1, 0.2, 0.3]:
    feature_masked = feature_masking(sample_graph, mask_rate=mask_rate)
    total_features = feature_masked.x.numel()
    masked_features = (feature_masked.x == float('-inf')).sum().item()
    print(f"  Mask rate {mask_rate:.1f}: {masked_features}/{total_features} features masked ({masked_features/total_features*100:.1f}%)")
print()

# 4. Edge Perturbation
print("4. Edge Perturbation:")
perturbation_configs = [
    (0.05, 0.05),
    (0.1, 0.1),
    (0.03, 0.07)
]
for add_rate, remove_rate in perturbation_configs:
    edge_perturbed = edge_perturbation(sample_graph, add_rate=add_rate, remove_rate=remove_rate)
    edge_change = edge_perturbed.edge_index.size(1) - sample_graph.edge_index.size(1)
    print(f"  Add {add_rate:.2f}, Remove {remove_rate:.2f}: {sample_graph.edge_index.size(1)} -> {edge_perturbed.edge_index.size(1)} edges ({edge_change:+d} net change)")
Example 7: Batch Processing and Quality Control¶
from torch_geometric.loader import DataLoader

def validate_graph_quality(graphs):
    """Check graph quality after augmentation"""
    issues = []

    for i, graph in enumerate(graphs):
        # Check for isolated nodes
        edge_index = graph.edge_index
        if edge_index.size(1) > 0:
            connected_nodes = torch.unique(edge_index.flatten())
            isolated_nodes = graph.num_nodes - len(connected_nodes)
            if isolated_nodes > 0:
                issues.append(f"Graph {i}: {isolated_nodes} isolated nodes")

        # Check for self-loops
        if edge_index.size(1) > 0:
            self_loops = (edge_index[0] == edge_index[1]).sum().item()
            if self_loops > 0:
                issues.append(f"Graph {i}: {self_loops} self-loops")

        # Check for masked features (encoded as -inf by feature masking)
        if torch.any(graph.x == float('-inf')):
            masked_count = (graph.x == float('-inf')).sum().item()
            issues.append(f"Graph {i}: {masked_count} masked features")

    return issues
# Create larger dataset for demonstration
smiles_list = [
    "CCO", "CC(=O)O", "c1ccccc1", "CCN(CC)CC", "CC(C)O",
    "CC(C)(C)O", "CC=O", "C1CCCCC1", "c1ccc2ccccc2c1", "CCCCO"
]
larger_graphs = []
for i in range(20):
    graph = smiles_to_graph(smiles_list[i % len(smiles_list)])
    if graph is not None:
        larger_graphs.append(graph)

print(f"Created dataset with {len(larger_graphs)} graphs")

# Apply batch augmentation
batch_augmented = augment_dataset(
    graphs=larger_graphs,
    augmentation_methods=['edge_drop', 'node_drop', 'feature_mask', 'edge_perturb'],
    edge_drop_rate=0.12,
    node_drop_rate=0.08,
    feature_mask_rate=0.20,
    edge_add_rate=0.05,
    edge_remove_rate=0.05,
    augment_percentage=0.6,
    seed=42
)

print(f"Batch augmentation: {len(larger_graphs)} -> {len(batch_augmented)} graphs")

# Quality validation
print("\nQuality validation:")
original_issues = validate_graph_quality(larger_graphs)
augmented_issues = validate_graph_quality(batch_augmented)

print(f"Original dataset issues: {len(original_issues)}")
if original_issues:
    for issue in original_issues[:5]:  # Show first 5 issues
        print(f"  {issue}")

print(f"Augmented dataset issues: {len(augmented_issues)}")
if augmented_issues:
    for issue in augmented_issues[:5]:  # Show first 5 issues
        print(f"  {issue}")

# Create DataLoader for training
train_loader = DataLoader(batch_augmented, batch_size=32, shuffle=True)
print("\nCreated DataLoader with batch size 32")
print(f"Number of batches: {len(train_loader)}")

# Examine first batch
for batch in train_loader:
    print(f"First batch: {batch.num_graphs} graphs, {batch.x.size(0)} total nodes")
    break
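If validation reports self-loops (edge perturbation can introduce them), PyTorch Geometric's remove_self_loops utility strips them while keeping edge_index and edge_attr aligned. A minimal cleanup sketch:

from torch_geometric.utils import remove_self_loops

# Drop self-loop edges; matching edge_attr rows are filtered in the same call
for graph in batch_augmented:
    graph.edge_index, graph.edge_attr = remove_self_loops(graph.edge_index, graph.edge_attr)
print("Self-loops removed from the augmented dataset")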
Example 8: Real-World Drug Discovery Application¶
# Simulate a more complex drug discovery scenario

# Create realistic molecular dataset
drug_molecules = [
    "CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O",        # Ibuprofen
    "CC(=O)Oc1ccccc1C(=O)O",                  # Aspirin
    "CC(=O)Nc1ccc(cc1)O",                     # Paracetamol
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",           # Caffeine
    "Clc1ccc(cc1)C(c2ccccc2)N3CCCC3",         # Loratadine
    "CC12CCC3C(C1CCC2O)CCC4=CC(=O)CCC34C",    # Testosterone
    "CCN(CC)CCNC(=O)c1cc(ccc1OC)S(=O)(=O)N",  # Sulpiride
    "Cc1ccc(cc1)C(=O)c2ccc(cc2)N(C)C",        # Michler's ketone
]

# Convert to graphs
drug_graphs = []
drug_names = ["Ibuprofen", "Aspirin", "Paracetamol", "Caffeine",
              "Loratadine", "Testosterone", "Sulpiride", "Michler's ketone"]

for smiles, name in zip(drug_molecules, drug_names):
    graph = smiles_to_graph(smiles)
    if graph is not None:
        graph.name = name
        graph.y = torch.randn(1)  # Simulated activity
        drug_graphs.append(graph)

print(f"Drug discovery dataset: {len(drug_graphs)} compounds")

# Apply pharmaceutical-grade augmentation
pharma_augmented = augment_dataset(
    graphs=drug_graphs,
    augmentation_methods=['edge_drop', 'feature_mask', 'edge_perturb'],
    edge_drop_rate=0.08,      # Conservative for drugs
    feature_mask_rate=0.12,   # Preserve chemical meaning
    edge_add_rate=0.03,       # Minimal structural changes
    edge_remove_rate=0.05,
    augment_percentage=0.4,   # 40% expansion
    seed=42
)

print(f"Pharmaceutical augmentation: {len(drug_graphs)} -> {len(pharma_augmented)} compounds")

# Analyze by original compound
print("\nAugmentation breakdown by compound:")
compound_counts = {}
for graph in pharma_augmented:
    if hasattr(graph, 'name'):
        compound_counts[graph.name] = compound_counts.get(graph.name, 0) + 1
    else:
        # Augmented copies may not carry the parent's name attribute
        compound_counts['Augmented'] = compound_counts.get('Augmented', 0) + 1

for compound, count in compound_counts.items():
    print(f"  {compound}: {count} variants")

# Prepare for virtual screening
virtual_library = DataLoader(pharma_augmented, batch_size=16, shuffle=False)

print("\nVirtual screening library prepared:")
print(f"  Total compounds: {len(pharma_augmented)}")
print(f"  Batches for screening: {len(virtual_library)}")

# Simulate screening results
screening_results = []
for batch in virtual_library:
    # Simulate screening scores
    scores = torch.randn(batch.num_graphs)
    screening_results.extend(scores.tolist())

print(f"  Screening completed: {len(screening_results)} scores generated")
print(f"  Top hit score: {max(screening_results):.3f}")
print(f"  Mean score: {sum(screening_results)/len(screening_results):.3f}")
These examples cover AugChem's graph augmentation toolkit end to end: converting SMILES to PyTorch Geometric graphs, expanding datasets with individual and batched augmentations, validating graph quality, and feeding the results into GNN training and virtual screening workflows.