You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

WeightedSum.c 5.2KB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #include "mex.h"
  2. void weightedSum(mwIndex* ir1, mwIndex* jc1, double* aff1,
  3. mwIndex* ir2, mwIndex* jc2, double* aff2,
  4. mwSize n, mwSize end, double w,
  5. mwIndex* ir_out, mwIndex* jc_out, double* outSimAtts)
  6. {
  7. int nnz = 0;
  8. int currCol, currRow;
  9. int i1, r1, startIdx1, stopIdx1;
  10. int i2, r2, startIdx2, stopIdx2;
  11. double val, val1, val2;
  12. int MAX = 2147483647;
  13. printf("starting weightedSum ...\n");
  14. for (currCol = 0; currCol < n; currCol++) {
  15. jc_out[currCol] = nnz;
  16. startIdx1 = (int)jc1[currCol];
  17. stopIdx1 = (int)jc1[currCol + 1];
  18. i1 = startIdx1;
  19. startIdx2 = (int)jc2[currCol];
  20. stopIdx2 = (int)jc2[currCol + 1];
  21. i2 = startIdx2;
  22. // printf("weightedSum: currCol = %d: Idx1=[%d,%d], Idx2=[%d,%d] ...\n",
  23. // currCol,startIdx1,stopIdx1,startIdx2,stopIdx2);
  24. while (i1 < stopIdx1 || i2 < stopIdx2) {
  25. if (i1 < stopIdx1)
  26. r1 = (int)ir1[i1];
  27. else
  28. r1 = MAX;
  29. if (i2 < stopIdx2)
  30. r2 = (int)ir2[i2];
  31. else
  32. r2 = MAX;
  33. // printf("weightedSum:currCol = %d, r1=%d, r2=%d ...\n",currCol, r1,r2);
  34. if (r1 < r2) {
  35. currRow = r1;
  36. val1 = aff1[i1];
  37. val2 = 0;
  38. i1++;
  39. } else if (r2 < r1) {
  40. currRow = r2;
  41. val1 = 0;
  42. val2 = aff2[i2];
  43. i2++;
  44. } else { // r1==r2
  45. currRow = r1;
  46. val1 = aff1[i1];
  47. val2 = aff2[i2];
  48. i1++; i2++;
  49. }
  50. // printf("weightedSum:currCol = %d, currRow = %d ...\n",currCol, currRow);
  51. if (currCol < end && currRow < end) {
  52. val = val1*(1-w)+val2*w;
  53. } else {
  54. val = val1;
  55. }
  56. if (val > 0) {
  57. outSimAtts[nnz] = val;
  58. ir_out[nnz] = currRow;
  59. nnz++;
  60. if (nnz % 1000000 == 0)
  61. printf("weightedSum nnz=%d M\n", nnz/1000000);
  62. }
  63. }
  64. }
  65. jc_out[n] = nnz;
  66. printf("end weightedSum => nnz=%d\n", nnz);
  67. }
  68. /* The gateway function
  69. //this function accepts two parameters:
  70. //1. a sparse structure affinity matrix
  71. //2. a sparse structure affinity matrix
  72. //3. weight
  73. //4. end index
  74. */
  75. void mexFunction( int nlhs, mxArray *plhs[],
  76. int nrhs, const mxArray *prhs[])
  77. {
  78. /* variable declarations here */
  79. mwIndex *ir1, *jc1;
  80. mwIndex *ir2, *jc2;
  81. double *aff1, *aff2;
  82. mwSize n,m,n2,m2;
  83. double w;
  84. int end;
  85. int nnz;
  86. mwIndex* ir_out;
  87. mwIndex* jc_out;
  88. double* outSimAtts;
  89. mwSize estimateNNZ;
  90. double percent_sparse = 1; //0.75;
  91. printf("weightedSum: nlhs = %d, nrhs = %d\n", nlhs, nrhs);
  92. /* Check for the correct number of outputs */
  93. if(nlhs != 1){
  94. mexErrMsgIdAndTxt( "MATLAB:weightedSum:wrongrhs",
  95. "Wrong number of output arguments.");
  96. }
  97. /* Check for the correct number of inputs */
  98. if(nrhs != 4){
  99. mexErrMsgIdAndTxt( "MATLAB:weightedSum:wrongrhs",
  100. "Wrong number of input arguments.");
  101. }
  102. aff1 = mxGetPr(prhs[0]); /* pointer to first input */
  103. ir1 = mxGetIr(prhs[0]);
  104. jc1 = mxGetJc(prhs[0]);
  105. aff2 = mxGetPr(prhs[1]); /* pointer to second input */
  106. ir2 = mxGetIr(prhs[1]);
  107. jc2 = mxGetJc(prhs[1]);
  108. w = mxGetScalar(prhs[2]); /* pointer to thrid input */
  109. end = (int)mxGetScalar(prhs[3]); /* pointer to fourth input */
  110. printf("weightedSum: w=%f, end=%d\n", w, end);
  111. /* dimensions of input matrices */
  112. m = mxGetM(prhs[0]);
  113. n = mxGetN(prhs[0]);
  114. m2 = mxGetM(prhs[1]);
  115. n2 = mxGetN(prhs[1]);
  116. printf("weightedSum dimensions: m=%d, n=%d, m2=%d, n2=%d\n", m, n, m2, n2);
  117. if (m != m2 || n != n2) {
  118. mexErrMsgIdAndTxt("MATLAB:attributesSimilarity:matchdims",
  119. "Inner dimensions of matrices do not match.");
  120. }
  121. if (n != m) {
  122. mexErrMsgIdAndTxt("MATLAB:attributesSimilarity:square",
  123. "Function requires input matrix 1 must be square.");
  124. }
  125. nnz = jc1[n]+jc2[n];
  126. printf("weightedSum: nnz=%d, percent_sparse=%f, mul=%f\n",
  127. nnz, percent_sparse, percent_sparse*nnz);
  128. estimateNNZ = (mwSize)(percent_sparse*nnz);
  129. printf("weightedSum: jc1[n]=%d, jc2[n]=%d, estimateNNZ=%d\n",
  130. jc1[n], jc2[n], estimateNNZ);
  131. // prepare output: nxn
  132. // allocate memeory according to estimateNNZ
  133. plhs[0] = mxCreateSparse(n,n, estimateNNZ, mxREAL);
  134. outSimAtts = mxGetPr(plhs[0]);
  135. ir_out = mxGetIr(plhs[0]);
  136. jc_out = mxGetJc(plhs[0]);
  137. // calculate the weighted sum
  138. // affinity(1:e, 1:e)= affinity(1:e,1:e)*(1-attWeight)+attAffinity(1:e,1:e)*attWeight;
  139. weightedSum(ir1, jc1, aff1, ir2, jc2, aff2, n, end, w, ir_out, jc_out, outSimAtts);
  140. }