{ "cells": [
  { "cell_type": "code", "execution_count": 1, "id": "e6ecf439-a0db-42e0-a6b9-f512198b0e0e", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch" ] },
  { "cell_type": "code", "execution_count": 4, "id": "4bcc7c7e-711a-4cd9-b901-d6ff76938a75", "metadata": { "tags": [] }, "outputs": [], "source": [
   "# both checkpoints compared in this section live in the same run directory\n",
   "run_dir = '/home/msadraei/trained_final/iclr_resp_t5_small_glue-cola/10_attempt'\n",
   "best_path = f'{run_dir}/best.pt'\n",
   "first_path = f'{run_dir}/first.pt'"
  ] },
  { "cell_type": "code", "execution_count": 5, "id": "eaa4a300-1e6c-46f0-8f0d-16e9c71c2388", "metadata": { "tags": [] }, "outputs": [], "source": [ "best = torch.load(best_path)\n", "first = torch.load(first_path)" ] },
  { "cell_type": "code", "execution_count": 8, "id": "c5e0b6bb-3bde-4526-8a6a-5dac0a3b3cc3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sadcl_p_target\n", "tensor(42.7208, device='cuda:0')\n", "pretrained_tasks\n", "tensor(0., device='cuda:0')\n", "sadcl_attention_score.g_network.0.weight\n", "tensor(157.3032, device='cuda:0')\n", "sadcl_attention_score.g_network.2.weight\n", "tensor(154.6590, device='cuda:0')\n", "sadcl_attention_score.g_network.3.weight\n", "tensor(18.1127, device='cuda:0')\n", "sadcl_attention_score.g_network.3.bias\n", "tensor(19.0149, device='cuda:0')\n" ] } ], "source": [
   "# for every entry of the best checkpoint: print its key, then torch.norm(first - best)\n",
   "for key, best_val in best.items():\n",
   "    print(key)\n",
   "    print(torch.norm(first[key] - best_val))"
  ] },
  { "cell_type": "code", "execution_count": 13, "id": "42815cf2-b8bf-4219-a3fd-ebbe92fb5c32", "metadata": {}, "outputs": [], "source": [
   "base_path = '/home/msadraei/trained_final/forward_transfer_test_t5_base_superglue-rte/10_combine_128_4tasks_new_impl_tie_50/100'\n",
   "last_path, best_path, first_path = (f'{base_path}/{name}.pt' for name in ('last', 'best', 'first'))"
  ] },
  { "cell_type": "code", "execution_count": 14, "id":
"880cb651-ddea-4564-93ab-c5f52e1f02dd", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "last = torch.load(last_path)\n", "best = torch.load(best_path)\n", "first = torch.load(first_path)" ] },
  { "cell_type": "code", "execution_count": 15, "id": "ee4b3287-203f-49b0-8b89-6070f9ff4062", "metadata": { "tags": [] }, "outputs": [], "source": [
   "import numpy as np\n",
   "\n",
   "\n",
   "def pretrained_coeff(state_dict):\n",
   "    \"\"\"Stack every `sadcl_coeff_pretrained` entry of a checkpoint into one array.\n",
   "\n",
   "    Entries are selected by substring match on the key, moved to CPU, converted\n",
   "    to NumPy and stacked along a new leading axis in the dict's iteration order,\n",
   "    so row i of the result comes from the i-th matching key.\n",
   "\n",
   "    Raises:\n",
   "        ValueError: if no key contains 'sadcl_coeff_pretrained' (np.stack would\n",
   "            otherwise fail with a generic 'need at least one array' error).\n",
   "    \"\"\"\n",
   "    coeffs = [\n",
   "        val.cpu().numpy()\n",
   "        for key, val in state_dict.items()\n",
   "        if 'sadcl_coeff_pretrained' in key\n",
   "    ]\n",
   "    if not coeffs:\n",
   "        raise ValueError('no sadcl_coeff_pretrained entries found in state_dict')\n",
   "    return np.stack(coeffs)"
  ] },
  { "cell_type": "code", "execution_count": 16, "id": "26518ecd-8cc1-4543-acaf-56637295bbe8", "metadata": { "tags": [] }, "outputs": [], "source": [
   "# one coefficient array per checkpoint, each read from its own state dict\n",
   "last_coeff = pretrained_coeff(last)\n",
   "best_coeff = pretrained_coeff(best)\n",
   "first_coeff = pretrained_coeff(first)"
  ] },
  { "cell_type": "code", "execution_count": null, "id": "5a850a65-724a-483d-abb3-b7de6118db31", "metadata": { "tags": [] }, "outputs": [], "source": [
   "# NOTE(review): the /100 rescaling is not explained here; confirm the intended scale\n",
   "np.round(last_coeff / 100, 2)"
  ] },
  { "cell_type": "code", "execution_count": 65, "id": "7182b595-724a-483d-abb3-b7de6118db31", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(34.9105)" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.linalg.vector_norm(torch.Tensor(best_coeff[0]), ord=1)" ] }
 ], "metadata": { "kernelspec": { "display_name":
"Python [conda env:deep]", "language": "python", "name": "conda-env-deep-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }