Adapted to the MovieLens dataset.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.sh 1.2KB

(lines 1–55)
  # Launch train_TaNP.py on GPU 0 with the hyperparameters defined below.
#
# Usage: bash train.sh [extra train_TaNP.py flags...]
#   Extra arguments are appended after the defaults. If train_TaNP.py parses
#   its flags with argparse (presumably — confirm), the last occurrence of a
#   flag wins, so e.g. `bash train.sh --lr 0.001` overrides the default lr.
#   With no extra arguments the command is identical to the fixed defaults.
set -euo pipefail

# Model dimensions (each forwarded verbatim as --<name> <value>).
first_embedding_dim=32
second_embedding_dim=16
z1_dim=32
z2_dim=32
z_dim=32
enc_h1_dim=32
enc_h2_dim=16
taskenc_h1_dim=32
taskenc_h2_dim=32
taskenc_final_dim=16
clusters_k=10
temperature=1.0
lambda=1.0
dec_h1_dim=32
dec_h2_dim=32
dec_h3_dim=16

# Optimization settings.
dropout_rate=0
lr=0.0001
optim='adam'
num_epoch=100
batch_size=32

# Data split / sampling settings.
train_ratio=0.7
valid_ratio=0.1
support_size=20
query_size=10
max_len=200
context_min=20

# Build the argument list as an array so every value is passed as exactly
# one word, with no word-splitting or globbing (ShellCheck SC2086).
args=(
  --first_embedding_dim "$first_embedding_dim"
  --second_embedding_dim "$second_embedding_dim"
  --z1_dim "$z1_dim"
  --z2_dim "$z2_dim"
  --z_dim "$z_dim"
  --enc_h1_dim "$enc_h1_dim"
  --enc_h2_dim "$enc_h2_dim"
  --taskenc_h1_dim "$taskenc_h1_dim"
  --taskenc_h2_dim "$taskenc_h2_dim"
  --taskenc_final_dim "$taskenc_final_dim"
  --clusters_k "$clusters_k"
  --lambda "$lambda"
  --temperature "$temperature"
  --dec_h1_dim "$dec_h1_dim"
  --dec_h2_dim "$dec_h2_dim"
  --dec_h3_dim "$dec_h3_dim"
  --lr "$lr"
  --dropout_rate "$dropout_rate"
  --optim "$optim"
  --num_epoch "$num_epoch"
  --batch_size "$batch_size"
  --train_ratio "$train_ratio"
  --valid_ratio "$valid_ratio"
  --support_size "$support_size"
  --query_size "$query_size"
  --max_len "$max_len"
  --context_min "$context_min"
)

# CUDA_VISIBLE_DEVICES=0 restricts the python process to the first GPU.
# "$@" forwards any caller-supplied flags after the defaults.
CUDA_VISIBLE_DEVICES=0 python train_TaNP.py "${args[@]}" "$@"