Skip to content

Graph Augmentation Methods

This section covers the advanced graph augmentation techniques available in Augchem, designed specifically for molecular graphs using PyTorch Geometric.

Core Augmentation Functions

Edge Dropping

augchem.modules.graph.graphs_modules.edge_dropping(data: Data, drop_rate: float = 0.1) -> Data

Remove complete bidirectional edges from the graph (edge dropping)

Parameters:

Name Type Description Default
data Data

torch_geometric graph

required
drop_rate float

Bidirectional edge removal rate (0.0 to 1.0)

0.1

Returns:

Type Description
Data

Graph with edges removed

Source code in augchem\modules\graph\graphs_modules.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def edge_dropping(data: Data, drop_rate: float = 0.1) -> Data:
    """
    Remove randomly chosen undirected edges from the graph (edge dropping).

    Directed entries ``(u, v)`` and ``(v, u)`` are treated as one undirected
    edge, so every directed copy of a selected edge is removed together.
    The selection is drawn from the global torch RNG; call
    ``torch.manual_seed`` beforehand for repeatable results.

    Args:
        data: torch_geometric graph
        drop_rate: Fraction of undirected edges to remove (0.0 to 1.0).
            At least one undirected edge is always removed when the graph
            has any edge, even if the rate rounds down to zero.

    Returns:
        New graph with the selected edges removed. Only ``x``,
        ``edge_index``, ``edge_attr``, ``num_nodes`` and ``y`` are carried
        over. A graph without edges is returned as a plain clone.
    """
    edge_index = data.edge_index
    if edge_index.size(1) == 0:
        return data.clone()

    device = edge_index.device

    # Normalise every directed edge to (low, high) so both directions of an
    # undirected edge collapse onto the same column.
    undirected = torch.stack([
        torch.minimum(edge_index[0], edge_index[1]),
        torch.maximum(edge_index[0], edge_index[1]),
    ])
    # pair_of_edge[i] is the index of the undirected edge that directed
    # edge i belongs to.
    unique_pairs, pair_of_edge = torch.unique(undirected, dim=1, return_inverse=True)
    num_unique_edges = unique_pairs.size(1)

    num_to_drop = max(1, int(num_unique_edges * drop_rate))

    # Pick the undirected edges to drop uniformly at random.
    drop_ids = torch.randperm(num_unique_edges)[:num_to_drop].to(device)
    drop_flags = torch.zeros(num_unique_edges, dtype=torch.bool, device=device)
    drop_flags[drop_ids] = True
    keep_mask = ~drop_flags[pair_of_edge]

    new_edge_index = edge_index[:, keep_mask]

    new_edge_attr = None
    if hasattr(data, 'edge_attr') and data.edge_attr is not None and data.edge_attr.size(0) > 0:
        new_edge_attr = data.edge_attr[keep_mask]

    new_data = Data(
        # Graphs without node features are passed through instead of crashing.
        x=data.x.clone() if data.x is not None else None,
        edge_index=new_edge_index,
        edge_attr=new_edge_attr,
        num_nodes=data.num_nodes
    )

    if hasattr(data, 'y') and data.y is not None:
        new_data.y = data.y.clone()

    return new_data

Node Dropping

augchem.modules.graph.graphs_modules.node_dropping(data: Data, drop_rate: float = 0.1) -> Data

Remove nodes randomly from the graph (node dropping)

Parameters:

Name Type Description Default
data Data

torch_geometric graph

required
drop_rate float

Node removal rate (0.0 to 1.0)

0.1

Returns:

Type Description
Data

Graph with nodes removed

Source code in augchem\modules\graph\graphs_modules.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def node_dropping(data: Data, drop_rate: float = 0.1) -> Data:
    """
    Remove nodes randomly from the graph (node dropping)

    Every edge touching a removed node is removed as well, and the
    remaining nodes are re-indexed to ``0 .. k-1`` in their original order.
    The selection is drawn from the global torch RNG.

    Args:
        data: torch_geometric graph
        drop_rate: Node removal rate (0.0 to 1.0). At least one node is
            always removed; a rate of 1.0 removes every node.

    Returns:
        New graph with nodes removed. Graphs with at most one node are
        returned as a plain clone. ``y`` is copied unchanged.
    """
    if data.num_nodes <= 1:
        return data.clone()

    num_nodes = data.num_nodes
    num_to_drop = max(1, int(num_nodes * drop_rate))

    # Index tensors must live where edge_index lives, otherwise the lookups
    # below fail for graphs that are not on the CPU.
    device = data.edge_index.device

    nodes_to_keep = torch.sort(torch.randperm(num_nodes)[num_to_drop:])[0].to(device)

    # old index -> new index, with -1 marking dropped nodes
    node_mapping = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    node_mapping[nodes_to_keep] = torch.arange(len(nodes_to_keep), device=device)

    # An edge survives only if both of its endpoints survive.
    edge_mask = (node_mapping[data.edge_index[0]] >= 0) & (node_mapping[data.edge_index[1]] >= 0)

    # Masking yields correctly shaped empty tensors when nothing survives,
    # so no special case is needed and edge_attr keeps its dtype and shape.
    new_edge_index = node_mapping[data.edge_index[:, edge_mask]]

    new_edge_attr = None
    if hasattr(data, 'edge_attr') and data.edge_attr is not None and data.edge_attr.size(0) > 0:
        new_edge_attr = data.edge_attr[edge_mask]

    new_data = Data(
        x=data.x[nodes_to_keep],
        edge_index=new_edge_index,
        edge_attr=new_edge_attr,
        num_nodes=len(nodes_to_keep)
    )

    # NOTE(review): y is copied as-is; a node-level y would no longer line up
    # with the kept nodes — confirm y is always graph-level.
    if hasattr(data, 'y') and data.y is not None:
        new_data.y = data.y.clone()

    return new_data

Feature Masking

augchem.modules.graph.graphs_modules.feature_masking(data: Data, mask_rate: float = 0.1) -> Data

Mask node features randomly (feature masking)

Parameters:

Name Type Description Default
data Data

torch_geometric graph

required
mask_rate float

Feature masking rate (0.0 to 1.0)

0.1

Returns:

Type Description
Data

Graph with masked features

Source code in augchem\modules\graph\graphs_modules.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def feature_masking(data: Data, mask_rate: float = 0.1) -> Data:
    """
    Mask node features randomly (feature masking)

    Each entry of the node feature matrix is independently replaced by
    ``-inf`` with probability ``mask_rate``. The input graph is not
    modified.

    Args:
        data: torch_geometric graph
        mask_rate: Feature masking rate (0.0 to 1.0)

    Returns:
        Clone of the graph with masked features. A graph with no node
        features (``x`` is None or has zero rows) is returned as a plain
        clone. Non-floating feature matrices are converted to the default
        floating dtype, because ``-inf`` cannot be stored in them.
    """
    new_data = data.clone()

    features = new_data.x
    if features is None or features.size(0) == 0:
        return new_data

    # Integer/bool features cannot hold -inf and are not supported by
    # rand_like, so promote them before masking.
    if not torch.is_floating_point(features):
        features = features.to(torch.get_default_dtype())

    mask = torch.rand_like(features) < mask_rate

    # masked_fill returns a new tensor, so no shared storage is touched.
    new_data.x = features.masked_fill(mask, float('-inf'))

    return new_data

Edge Perturbation

augchem.modules.graph.graphs_modules.edge_perturbation(data: Data, add_rate: float = 0.05, remove_rate: float = 0.05) -> Data

Perturb the graph by adding and removing complete bidirectional edges (edge perturbation)

Parameters:

Name Type Description Default
data Data

torch_geometric graph

required
add_rate float

Bidirectional connection addition rate

0.05
remove_rate float

Bidirectional connection removal rate

0.05

Returns:

Type Description
Data

Perturbed graph

Source code in augchem\modules\graph\graphs_modules.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def edge_perturbation(data: Data, add_rate: float = 0.05, remove_rate: float = 0.05) -> Data:
    """
    Perturb the graph by adding and removing complete bidirectional edges (edge perturbation)

    Edges are first removed with ``edge_dropping(data, remove_rate)``; new
    undirected edges are then sampled among the node pairs that are not
    connected after that removal, and each is inserted in both directions.
    Sampling uses the global torch RNG and does not reseed it, so callers
    control reproducibility with ``torch.manual_seed``.

    Args:
        data: torch_geometric graph
        add_rate: Fraction of the currently unconnected node pairs to connect
        remove_rate: Bidirectional connection removal rate

    Returns:
        Perturbed graph. When the graph carries edge attributes, each added
        edge receives the mean of the attributes that remain after removal,
        or zeros if no attributed edge remains.
    """
    perturbed_data = edge_dropping(data, remove_rate)
    edge_index = perturbed_data.edge_index

    # Undirected pairs already present, normalised to (low, high). Self-loops
    # are skipped: they are never candidates for addition, so counting them
    # would understate the number of free pairs.
    existing_bidirectional = {
        (min(src, dst), max(src, dst))
        for src, dst in edge_index.t().tolist()
        if src != dst
    }

    num_nodes = data.num_nodes
    max_possible_connections = num_nodes * (num_nodes - 1) // 2
    available_connections = max_possible_connections - len(existing_bidirectional)

    num_connections_to_add = int(available_connections * add_rate)
    if num_connections_to_add <= 0:
        return perturbed_data

    # Enumerated in lexicographic order so the candidate list is stable.
    available_connections_list = [
        (i, j)
        for i in range(num_nodes)
        for j in range(i + 1, num_nodes)
        if (i, j) not in existing_bidirectional
    ]
    if not available_connections_list:
        return perturbed_data

    num_to_add = min(num_connections_to_add, len(available_connections_list))
    chosen = torch.randperm(len(available_connections_list))[:num_to_add].tolist()

    new_bidirectional_edges = []
    for idx in chosen:
        src, dst = available_connections_list[idx]
        new_bidirectional_edges.extend([[src, dst], [dst, src]])

    new_edge_index = torch.tensor(
        new_bidirectional_edges, dtype=torch.long, device=edge_index.device
    ).t()
    perturbed_data.edge_index = torch.cat([edge_index, new_edge_index], dim=1)

    edge_attr = getattr(perturbed_data, 'edge_attr', None)
    if edge_attr is not None:
        num_new = new_edge_index.size(1)
        trailing_shape = tuple(edge_attr.shape[1:])
        if edge_attr.size(0) > 0:
            new_edge_attrs = edge_attr.mean(dim=0, keepdim=True).expand(num_new, *trailing_shape)
        else:
            # No attributed edge is left to average; pad with zeros so that
            # edge_attr keeps one row per column of edge_index.
            new_edge_attrs = edge_attr.new_zeros((num_new, *trailing_shape))
        perturbed_data.edge_attr = torch.cat([edge_attr, new_edge_attrs], dim=0)

    return perturbed_data

Dataset Augmentation

augchem.modules.graph.graphs_modules.augment_dataset(graphs: List[Data], augmentation_methods: List[str], edge_drop_rate: float = 0.1, node_drop_rate: float = 0.1, feature_mask_rate: float = 0.1, edge_add_rate: float = 0.05, edge_remove_rate: float = 0.05, augment_percentage: float = 0.2, seed: int = 42) -> List[Data]

Apply data augmentation techniques to a list of graphs.

Parameters:

Name Type Description Default
graphs List[Data]

List of torch_geometric Data objects representing the graphs

required
augmentation_methods List[str]

List of methods ['edge_drop', 'node_drop', 'feature_mask', 'edge_perturb']

required
edge_drop_rate float

Rate of edge removal (0.0 to 1.0)

0.1
node_drop_rate float

Rate of node removal (0.0 to 1.0)

0.1
feature_mask_rate float

Rate of feature masking (0.0 to 1.0)

0.1
edge_add_rate float

Rate of edge addition for perturbation

0.05
edge_remove_rate float

Rate of edge removal for perturbation

0.05
augment_percentage float

Number of new augmented graphs to generate, as a fraction of the original dataset size (rounded down)

0.2
seed int

Seed for reproducibility

42

Returns:

Type Description
List[Data]

List of augmented graphs (original + augmented)

Raises:

Type Description
ValueError

If the list of graphs is empty or an unknown augmentation method is specified

Source code in augchem\modules\graph\graphs_modules.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
def augment_dataset(graphs: List[Data], augmentation_methods: List[str],
                          edge_drop_rate: float = 0.1, node_drop_rate: float = 0.1,
                          feature_mask_rate: float = 0.1, edge_add_rate: float = 0.05,
                          edge_remove_rate: float = 0.05, augment_percentage: float = 0.2,
                          seed: int = 42) -> List[Data]:
    """
    Apply data augmentation techniques to a list of graphs.

    The requested methods are applied round-robin, each time to a randomly
    selected input graph, until ``int(len(graphs) * augment_percentage)``
    new graphs have been produced. Every augmented graph is tagged with
    ``augmentation_method`` and ``parent_idx`` (index of its source graph).

    Args:
        graphs: List of torch_geometric Data objects representing the graphs
        augmentation_methods: List of methods ['edge_drop', 'node_drop', 'feature_mask', 'edge_perturb']
        edge_drop_rate: Rate of edge removal (0.0 to 1.0)
        node_drop_rate: Rate of node removal (0.0 to 1.0)
        feature_mask_rate: Rate of feature masking (0.0 to 1.0)
        edge_add_rate: Rate of edge addition for perturbation
        edge_remove_rate: Rate of edge removal for perturbation
        augment_percentage: Number of new graphs to generate, as a fraction
            of the number of input graphs (rounded down)
        seed: Seed for the selection of source graphs. Randomness inside the
            individual augmentation functions is not driven by this value.

    Returns:
        Clones of the original graphs followed by the augmented graphs.
        If ``augmentation_methods`` is empty, only the clones are returned.

    Raises:
        ValueError: If ``graphs`` is empty or an unknown augmentation
            method is specified
    """

    if not graphs:
        raise ValueError("List of graphs cannot be empty")

    valid_methods = {'edge_drop', 'node_drop', 'feature_mask', 'edge_perturb'}
    for method in augmentation_methods:
        if method not in valid_methods:
            raise ValueError(f"Unknown augmentation method: {method}. Valid methods: {valid_methods}")

    # One callable per method name replaces four near-identical branches.
    augmenters = {
        'edge_drop': lambda graph: edge_dropping(graph, drop_rate=edge_drop_rate),
        'node_drop': lambda graph: node_dropping(graph, drop_rate=node_drop_rate),
        'feature_mask': lambda graph: feature_masking(graph, mask_rate=feature_mask_rate),
        'edge_perturb': lambda graph: edge_perturbation(
            graph, add_rate=edge_add_rate, remove_rate=edge_remove_rate
        ),
    }

    rng = np.random.RandomState(seed)

    working_graphs = [graph.clone() for graph in graphs]

    target_new_graphs = int(len(graphs) * augment_percentage)
    if not augmentation_methods:
        # With no method nothing can ever be produced; looping would never end.
        target_new_graphs = 0

    augmented_graphs: List[Data] = []

    # Failures stay best-effort (reported and skipped), but an unbroken run of
    # them means no progress is possible, so stop instead of spinning forever.
    max_consecutive_failures = 100
    consecutive_failures = 0
    gave_up = False

    while len(augmented_graphs) < target_new_graphs and not gave_up:
        for method in augmentation_methods:
            graph_to_augment = rng.randint(low=0, high=len(working_graphs))
            original_graph = working_graphs[graph_to_augment]

            try:
                augmented_graph = augmenters[method](original_graph)
            except Exception as e:
                print(f"Error during augmentation: {e}")
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    print(f"Stopping augmentation after {consecutive_failures} consecutive errors")
                    gave_up = True
                    break
                continue

            consecutive_failures = 0

            augmented_graph.augmentation_method = method
            augmented_graph.parent_idx = graph_to_augment
            augmented_graphs.append(augmented_graph)

            if len(augmented_graphs) >= target_new_graphs:
                break

    all_graphs = working_graphs + augmented_graphs

    print(f"Augmenting finished: {len(working_graphs)} originals + {len(augmented_graphs)} augmented = {len(all_graphs)} total")

    return all_graphs

Usage Examples

Basic Graph Augmentation

from augchem.modules.graph.graphs_modules import augment_dataset
import torch
from torch_geometric.data import Data

# Example: create sample molecular graphs.
# augment_dataset generates int(len(graphs) * augment_percentage) new graphs,
# so the dataset must be large enough for that product to reach at least 1.
graphs = [
    Data(
        x=torch.randn(num_nodes, 5),
        edge_index=torch.randint(0, num_nodes, (2, 2 * num_nodes)),
    )
    for num_nodes in (10, 8, 12, 9, 11, 7, 10, 8, 12, 9)
]

# Apply multiple augmentation techniques (10 graphs * 0.3 -> 3 new graphs)
augmented_graphs = augment_dataset(
    graphs=graphs,
    augmentation_methods=['edge_drop', 'node_drop', 'feature_mask', 'edge_perturb'],
    edge_drop_rate=0.1,
    node_drop_rate=0.1,
    feature_mask_rate=0.15,
    edge_add_rate=0.05,
    edge_remove_rate=0.05,
    augment_percentage=0.3,
    seed=42
)

# The returned list holds the originals followed by the new graphs.
print(f"Original: {len(graphs)} graphs")
print(f"After augmentation: {len(augmented_graphs)} graphs (originals + augmented)")

Individual Augmentation Techniques

from augchem.modules.graph.graphs_modules import (
    edge_dropping,
    node_dropping,
    feature_masking,
    edge_perturbation,
)

# Placeholder: substitute your own torch_geometric Data object here
graph = your_molecular_graph

# Edge dropping: removes both directions of each selected connection
graph_edge_drop = edge_dropping(graph, drop_rate=0.1)

# Node dropping: removes nodes together with the edges attached to them
graph_node_drop = node_dropping(graph, drop_rate=0.1)

# Feature masking: overwrites the selected node feature entries with -inf
graph_feature_mask = feature_masking(graph, mask_rate=0.15)

# Edge perturbation: removes some edges, then adds new bidirectional ones
graph_perturbed = edge_perturbation(graph, add_rate=0.05, remove_rate=0.05)

Working with PyTorch Geometric DataLoaders

from torch_geometric.loader import DataLoader

# Wrap the augmented graph list in a shuffling mini-batch loader
dataloader = DataLoader(augmented_graphs, batch_size=32, shuffle=True)

# Inspect the first batch only
for batch in dataloader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.x.size(0)}")
    print(f"Total edges: {batch.edge_index.size(1)}")
    break

Technical Notes

Graph Integrity

  • All augmentation functions preserve graph structure validity
  • Node indices are properly remapped after node dropping
  • Edge attributes are handled consistently across operations

Bidirectional Edges

  • Edge dropping and perturbation work with complete bidirectional edges
  • This ensures molecular graph connectivity is maintained properly
  • Single-direction edge operations would break chemical bond representation

Feature Masking

  • Uses -inf as mask value for compatibility with attention mechanisms
  • Masked features can be easily identified and handled in downstream models
  • Preserves tensor shapes for batch processing

Reproducibility

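  • augment_dataset accepts a seed that controls which graphs are selected for augmentation
  • The individual augmentation functions take no seed argument; call torch.manual_seed beforehand to obtain repeatable results from them
  • Essential for experimental reproducibility in research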

Memory Efficiency

  • All functions create cloned graphs to preserve originals
  • Efficient tensor operations using PyTorch primitives
  • Batch processing optimized for GPU acceleration