Browse Source

train function

master
Yassaman Ommi 2 years ago
parent
commit
16a82324f1
1 changed files with 5 additions and 1 deletions
  1. 5
    1
      GraphTransformer.py

+ 5
- 1
GraphTransformer.py View File

@@ -210,7 +210,11 @@ class Hydra(nn.Module):

return mlp_output

# train
""""
Train the Model
Prepare data using DataLoader
(data can't be batched)
"""

def build_model(gcn_input, model_dim, head):
model = Hydra(gcn_input, model_dim, head).cuda()

Loading…
Cancel
Save