Browse Source

train function

master
Yassaman Ommi 3 years ago
parent
commit
c165618760
1 changed files with 2 additions and 2 deletions
  1. 2
    2
      GraphTransformer.py

+ 2
- 2
GraphTransformer.py View File

@@ -210,9 +210,9 @@ class Hydra(nn.Module):

return mlp_output

''''
""""
hala train
''''
""""
def build_model(gcn_input, model_dim, head):
model = Hydra(gcn_input, model_dim, head).cuda()
return model

Loading…
Cancel
Save