Browse Source

train function

master
Yassaman Ommi 3 years ago
parent
commit
16a82324f1
1 changed files with 5 additions and 1 deletions
  1. 5
    1
      GraphTransformer.py

+ 5
- 1
GraphTransformer.py View File



return mlp_output return mlp_output


# train
""""
Train the Model
Prepare data using DataLoader
(data can't be batched)
"""


def build_model(gcn_input, model_dim, head): def build_model(gcn_input, model_dim, head):
model = Hydra(gcn_input, model_dim, head).cuda() model = Hydra(gcn_input, model_dim, head).cuda()

Loading…
Cancel
Save