Instructions to use Cainiao-AI/TAAS with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Cainiao-AI/TAAS with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Cainiao-AI/TAAS", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Cainiao-AI/TAAS", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| #! python3 | |
| # -*- encoding: utf-8 -*- | |
| import torch | |
| import torch.nn.functional as F | |
| import pandas as pd | |
| import sys | |
| import os | |
| from transformers.utils.hub import cached_file | |
# Locate 'htc_mask_dict_old.pkl' in the 'Cainiao-AI/TAAS' hub repository.
# NOTE: this runs at import time, so importing this module may hit the
# network (or the local hub cache) before anything else happens.
resolved_module_file = cached_file(
    'Cainiao-AI/TAAS',
    'htc_mask_dict_old.pkl'
)
# One loss weight per code level (five levels); HTCLoss.forward tiles this
# list across the batch. Deeper levels get larger weights.
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]
# SECURITY NOTE(review): pd.read_pickle unpickles the downloaded file, and
# unpickling can execute arbitrary code. Only load this from a repository
# (and revision) you trust; consider pinning a revision or shipping the table
# in a data-only format instead.
# The loaded object is used as a mapping from a code-prefix string to the
# child codes allowed under that prefix (see HTCLoss).
htc_mask_dict = pd.read_pickle(resolved_module_file)
| import numpy as np | |
| import operator | |
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len = 6):
    """Count prefix-level hits between predicted and gold 5-level codes.

    Both inputs are reshaped to ``(-1, sequence_len, 5)``. For every code and
    every depth ``d`` (0..4) the first ``d + 1`` gold levels are compared with
    the first ``d + 1`` predicted levels. As soon as a gold prefix contains
    the ignore value ``-100`` the remaining depths of that code are skipped.

    Args:
        predicted_htc: Flat (or nested) sequence of predicted level codes;
            converted with ``np.array`` before reshaping.
        y: Gold codes. Any object with ``reshape`` and ``tolist`` works
            (e.g. a torch tensor or a numpy array); ``reshape`` is used
            instead of ``view`` so non-contiguous tensors are accepted.
        sequence_len: Number of codes per batch element.

    Returns:
        ``(acc_cnt, total_cnt)``: two length-5 integer numpy arrays holding,
        per depth, the number of matching prefixes and the number of
        prefixes that were evaluated.
    """
    num_levels = 5
    labels = y.reshape(-1, sequence_len, num_levels).tolist()
    predictions = np.array(predicted_htc).reshape(-1, sequence_len, num_levels).tolist()
    acc_cnt = np.zeros(num_levels, dtype=int)
    total_cnt = np.zeros(num_levels, dtype=int)
    for batch_i, label_seq in enumerate(labels):
        pred_seq = predictions[batch_i]
        for pos, label_codes in enumerate(label_seq):
            pred_codes = pred_seq[pos]
            for depth in range(num_levels):
                label_prefix = label_codes[:depth + 1]
                # An ignored level invalidates this depth and all deeper ones.
                if -100 in label_prefix:
                    break
                total_cnt[depth] += 1
                if label_prefix == pred_codes[:depth + 1]:
                    acc_cnt[depth] += 1
    return acc_cnt, total_cnt
class HTCLoss(torch.nn.Module):
    """Level-weighted cross-entropy over 5-level codes with 100 classes each.

    Logits are treated as rows of 100 class scores, five consecutive rows per
    sample (one per code level). With ``using_htc`` enabled, the logits are
    first masked top-down: at each level only the children listed in the
    hierarchy table for the greedily decoded prefix keep their score, every
    other class is set to -5.

    The hierarchy table maps a prefix string (codes formatted as
    ``02d, 02d, 02d, 1d`` and concatenated) to the allowed child codes.
    NOTE(review): the values are assumed to be integer index sequences, since
    they are used both to index the logits and as the decoded code values.
    """

    def __init__(self, device, reduction='mean', using_htc = True):
        """Store settings and build a per-instance tensor copy of the table.

        Args:
            device: Device the table, masks and weights are placed on.
            reduction: ``'mean'`` or ``'sum'``; any other value leaves the
                per-row loss unreduced.
            using_htc: Whether to apply the hierarchy-constrained masking.
        """
        super(HTCLoss, self).__init__()
        self.reduction = reduction
        self.htc_weights = htc_weights
        self.device = device
        self.using_htc = using_htc
        # Build a private copy instead of overwriting the module-level table:
        # mutating the shared dict made every new instance re-convert and
        # re-move the tensors used by all existing instances.
        self.htc_mask_dict = {
            key: torch.tensor(value).clone().detach().to(self.device)
            for key, value in htc_mask_dict.items()
        }

    def _constrained_decode(self, logits):
        """Greedily decode codes top-down and mask logits to valid children.

        Returns:
            ``(logits_new, predict_res)`` where ``logits_new`` has shape
            ``(-1, 5, 100)`` (-5 everywhere except the allowed classes) and
            ``predict_res`` is a flat list of 5 ints per sample.
        """
        logits_reshaped = logits.reshape(-1, 5, 100)
        # Level 0 is restricted to classes 1..31; shift the argmax back to
        # the absolute class index.
        _, aa_predicted = torch.max(logits_reshaped[:, 0, 1:32], 1)
        aa_predicted = aa_predicted + 1
        logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
        logits_new[:, 0, 1:32] = logits_reshaped[:, 0, 1:32]
        # How the codes of levels 1..3 are appended to the lookup key.
        suffix_formats = ('{:02d}', '{:02d}', '{:01d}')
        predict_res = []
        for sample_idx, aa in enumerate(aa_predicted.tolist()):
            codes = [aa]
            key = '{:02d}'.format(aa)
            for level in range(1, 5):
                candidates = self.htc_mask_dict[key]
                level_logits = logits_reshaped[sample_idx, level, candidates]
                _, best = torch.max(level_logits, 0)
                code = candidates[best].item()
                logits_new[sample_idx, level, candidates] = level_logits
                codes.append(code)
                if level < 4:
                    key += suffix_formats[level - 1].format(code)
            predict_res.extend(codes)
        return logits_new, predict_res

    def forward(self, logits, target):
        """Compute the weighted loss and the decoded codes.

        Args:
            logits: Class scores reshapeable to ``(-1, 100)``, five rows per
                sample.
            target: Gold class per row; ``-100`` marks rows to ignore.

        Returns:
            ``(loss, predict_res)``. ``predict_res`` is the flat list of
            decoded codes when ``using_htc`` is set, otherwise ``[]``.
            With ``'mean'``/``'sum'`` only strictly positive row losses are
            reduced; if none exist the mean of the empty selection is NaN.
        """
        # The target-derived tensors are expected to be on the same device.
        target = target.reshape(-1, 1)
        target_mask = (target != -100).reshape(-1)
        # Ignored rows get a dummy class 0 so scatter_ has a valid index;
        # their loss is zeroed by target_mask below.
        target_new = target.clone()
        target_new[target == -100] = 0
        predict_res = []
        if self.using_htc:
            logits_new, predict_res = self._constrained_decode(logits)
            log_pro = -1.0 * F.log_softmax(logits_new.reshape(-1, 100), dim=1)
        else:
            log_pro = -1.0 * F.log_softmax(logits.reshape(-1, 100), dim=1)
        one_hot = torch.zeros(log_pro.shape[0], log_pro.shape[1]).to(self.device)
        one_hot = one_hot.scatter_(1, target_new, 1)
        loss = torch.mul(log_pro, one_hot).sum(dim=1)
        loss = loss * target_mask
        # Tile the five per-level weights over all samples.
        bs = int(loss.shape[0] / 5)
        w_loss = torch.FloatTensor(list(self.htc_weights) * bs).to(self.device)
        loss = loss.mul(w_loss) * 5
        if self.reduction == 'mean':
            loss = loss[loss > 0].mean()
        elif self.reduction == 'sum':
            loss = loss[loss > 0].sum()
        return loss, predict_res

    def get_htc_code(self, logits):
        """Return the hierarchy-constrained greedy decode as a flat int list."""
        _, predict_res = self._constrained_decode(logits)
        return predict_res