@@ -210,7 +210,11 @@ class Hydra(nn.Module): | |||
return mlp_output | |||
# train | |||
"""" | |||
Train the Model | |||
Prepare data using DataLoader | |||
(data can't be batched) | |||
""" | |||
def build_model(gcn_input, model_dim, head): | |||
model = Hydra(gcn_input, model_dim, head).cuda() |