# Training > 本文档详细介绍了如何使用TuGraph进行图神经网络(GNN)的训练。 ## 1. 训练 使用TuGraph 图学习模块进行训练时,可以分为全图训练和mini-batch训练。 全图训练即把全图从TuGraph db加载到内存中,再进行GNN的训练。而mini-batch训练则使用上面提到的TuGraph 图学习模块的采样算子,将全图数据进行采样后,再送入训练框架中进行训练。 ## 2. Mini-Batch训练 Mini-Batch训练需要使用TuGraph 图学习模块的采样算子,目前支持Neighbor Sampling、Edge Sampling、Random Walk Sampling和Negative Sampling。 TuGraph 图学习模块的采样算子进行采样后的结果以List的形式返回。 下面以Neighbor Sampling为例,介绍如何将采样后的结果,进行格式转换,送入到训练框架中进行训练。 用户需要提供一个Sample类: ```python class TuGraphSample(object): def __init__(self, args=None): super(TuGraphSample, self).__init__() self.args = args def sample(self, g, seed_nodes): args = self.args # 1. 加载图数据 galaxy = PyGalaxy(args.db_path) galaxy.SetCurrentUser(args.username, args.password) db = galaxy.OpenGraph(args.graph_name, False) sample_node = seed_nodes.tolist() length = args.randomwalk_length NodeInfo = [] EdgeInfo = [] # 2. 采样方法,结果存储在NodeInfo和EdgeInfo中 if args.sample_method == 'randomwalk': randomwalk.Process(db, 100, sample_node, length, NodeInfo, EdgeInfo) elif args.sample_method == 'negative': negativesample.Process(db, 100) else: neighborsample(db, 100, sample_node, args.nbor_sample_num, NodeInfo, EdgeInfo) del db del galaxy # 3. 对结果进行格式转换,使之符合训练格式 remap(EdgeInfo[0], EdgeInfo[1], NodeInfo[0]) g = dgl.graph((EdgeInfo[0], EdgeInfo[1])) g.ndata['feat'] = torch.tensor(NodeInfo[1]) g.ndata['label'] = torch.tensor(NodeInfo[2]) return g ``` 如代码所示,首先将图数据加载到内存中。然后使用采样算子对图数据进行采样,结果存储在NodeInfo和EdgeInfo中。NodeInfo和EdgeInfo是python list结果,其存储的信息结果如下: | 图数据 | 存储信息位置 | | --- | --- | | 边起点 | EdgeInfo[0] | | 边终点 | EdgeInfo[1] | | 顶点ID | NodeInfo[0] | | 顶点特征 | NodeInfo[1] | | 顶点标签 | NodeInfo[2] | 最后对结果进行格式转换,使之符合训练格式。这里我们使用的是DGL训练框架,因此使用结果数据构造了DGL Graph,最终将DGL Graph返回。 我们提供TuGraphSample类之后,就可以使用它进行Mini-Batch训练了。 令DGL的数据加载部分使用TuGraphSample的实例sampler: ```python sampler = TugraphSample(args) fake_g = construct_graph() # just make dgl happy dataloader = dgl.dataloading.DataLoader(fake_g, torch.arange(train_nids), sampler, batch_size=batch_size, device=device, use_ddp=True, num_workers=0, drop_last=False, ) ``` 使用DGL进行模型训练: ```python def train(dataloader, model): optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4) model.train() s = time.time() for graph in dataloader: load_time = time.time() graph = dgl.add_self_loop(graph) logits = model(graph, graph.ndata['feat']) loss = loss_fcn(logits, graph.ndata['label']) optimizer.zero_grad() loss.backward() optimizer.step() train_time = time.time() print('load time', load_time - s, 'train_time', train_time - load_time) s = time.time() return float(loss) ``` ## 3. 全图训练 GNN(图神经网络)的全图训练是一种涉及一次处理整个训练数据集的训练。它是 GNN 最简单、最直接的训练方法之一,整个图被视为单个实例。 在全图训练中,整个数据集被加载到内存中,模型在整个图上进行训练。这种类型的训练对于中小型图特别有用,并且主要用于不随时间变化的静态图。 在算子调用时,使用以下方式: ```python getdb.Process(db, olapondb, feature_len, NodeInfo, EdgeInfo) ``` 获取全图数据,然后将全图送入训练框架中进行训练。 完整代码:请参考learn/examples/train_full_cora.py。