You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cascade_embedding.py 3.4KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import json
  2. import logging
  3. import os
  4. from sentence_transformers import SentenceTransformer
  5. import numpy as np
  6. import pandas as pd
  7. from sklearn.decomposition import PCA
  8. import pickle
# Directory containing this script; all data paths below are resolved relative to it.
ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
# Sentence-transformer model used throughout this module; encode() yields
# 768-dimensional vectors (the accumulators below are sized to match).
# NOTE: loading happens at import time and may download weights on first use.
model = SentenceTransformer('paraphrase-distilroberta-base-v1')
# Log to pca_test.log with timestamped INFO-level messages.
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', filename='pca_test.log', level=logging.INFO,
                    datefmt='%Y-%m-%d %H:%M:%S')
  13. def embed_cascades_text():
  14. directory = os.fsencode(os.path.join(ROOT_DIR, 'twitter-raw-data/Twitter/'))
  15. cascades_text_embeded = {}
  16. for cascade_name in os.listdir(directory):
  17. cascade_dir = os.fsencode(os.path.join(ROOT_DIR, 'twitter-raw-data/Twitter', str(cascade_name).split('\'')[1]))
  18. cascade_embed = np.array([0.0 for _ in range(768)])
  19. ctr = 0
  20. for cascade in os.listdir(cascade_dir):
  21. fd = open(os.path.join(str(cascade_dir).split('\'')[1], str(cascade).split('\'')[1]), 'r')
  22. cascade_tweet_text = fd.read()
  23. fd.close()
  24. print(cascade_tweet_text)
  25. cascade_tweet_json = json.loads(cascade_tweet_text)
  26. cascade_tweet_text = cascade_tweet_json['tweet']['text']
  27. tweet_embed = model.encode(cascade_tweet_text)
  28. cascade_embed = np.add(cascade_embed, tweet_embed)
  29. ctr += 1
  30. cascade_embed = cascade_embed / ctr
  31. cascades_text_embeded[str(cascade_name).split('\'')[1]] = cascade_embed
  32. print(cascade_tweet_text)
  33. return cascades_text_embeded
  34. def embed_user_cascades(user_cascades):
  35. fake_casc_percent = 0
  36. user_tweets_embed = np.array([0 for _ in range(768)])
  37. global model
  38. for cascade in user_cascades:
  39. fake_casc_percent += int(cascade[1])
  40. user_tweets_embed = np.add(user_tweets_embed, model.encode(cascade[2]))
  41. user_tweets_embed = user_tweets_embed / len(user_cascades)
  42. fake_casc_percent = fake_casc_percent / len(user_cascades)
  43. return user_tweets_embed, fake_casc_percent
  44. def user_embedding(users_dict: dict):
  45. global model
  46. users_bio = pd.DataFrame()
  47. user_ids = []
  48. users_bio = None
  49. users_tweets_embed = None
  50. logging.info("start embedding.")
  51. for user_id, user_info in users_dict.items():
  52. user_ids += user_id
  53. print(user_info)
  54. user_bio_embed = model.encode(user_info['profile_features']['description'])
  55. user_tweets_embed, fake_casc_percent = embed_user_cascades(user_info['cascades_feature'])
  56. if users_bio is None:
  57. users_bio = [user_bio_embed.tolist()]
  58. users_tweets_embed = [user_tweets_embed.tolist()]
  59. else:
  60. users_bio = np.append(users_bio, [user_bio_embed.tolist()], axis=0)
  61. users_tweets_embed = np.append(users_tweets_embed, [user_tweets_embed.tolist()], axis=0)
  62. pca = PCA(n_components=200)
  63. users_bio = pca.fit_transform(users_bio)
  64. users_tweets_embed = pca.fit_transform(users_tweets_embed)
  65. logging.info("users bio: {0}".format(users_bio))
  66. logging.info("users tweets embed".format(users_tweets_embed))
  67. user_dict = pickle.load(open(os.path.join(ROOT_DIR, '../../Representations/users_data.p'), "rb"))
  68. user_embedding(user_dict)
  69. # user_embedding({'12': {'description':'hi to you', 'cascades_feature':[[12, 1, 'this is a test']]},'13': {'description':'hi to me', 'cascades_feature':[[12, 1, 'this is not a test']]}})