

DirectAU


This repository contains the implementation of the KDD'22 paper "Towards Representation Alignment and Uniformity in Collaborative Filtering".

This work investigates the desired properties of representations in collaborative filtering (CF) from the perspective of alignment and uniformity. DirectAU is a new learning objective for CF-based recommender systems that directly optimizes the alignment and uniformity of user and item representations on the hypersphere. In the paper's experiments, a simple matrix factorization (MF) encoder trained with this loss outperforms state-of-the-art (SOTA) CF methods.
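For reference, the objective computed by the training code below can be written as follows. The notation is ours and may differ from the paper's: f is the encoder, f̃ its L2-normalized output, p_pos the distribution of observed (user, item) pairs, and γ the uniformity weight (the --gamma option). In practice each expectation is estimated over a training batch.

```latex
\begin{aligned}
\tilde{f}(x) &= f(x) / \lVert f(x) \rVert_2 \\
\ell_{\text{align}} &= \mathbb{E}_{(u,i) \sim p_{\text{pos}}} \lVert \tilde{f}(u) - \tilde{f}(i) \rVert_2^2 \\
\ell_{\text{uniform}} &= \tfrac{1}{2} \log \mathbb{E}_{u,u' \sim p_{\text{user}}} e^{-2 \lVert \tilde{f}(u) - \tilde{f}(u') \rVert_2^2}
  + \tfrac{1}{2} \log \mathbb{E}_{i,i' \sim p_{\text{item}}} e^{-2 \lVert \tilde{f}(i) - \tilde{f}(i') \rVert_2^2} \\
\mathcal{L}_{\text{DirectAU}} &= \ell_{\text{align}} + \gamma \, \ell_{\text{uniform}}
\end{aligned}
```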

Training with DirectAU

This learning objective is easy to implement as follows (PyTorch-style):

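import torch
import torch.nn.functional as F

@staticmethod
def alignment(x, y):
    # Mean squared L2 distance between normalized embeddings of positive (user, item) pairs
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return (x - y).norm(p=2, dim=1).pow(2).mean()

@staticmethod
def uniformity(x):
    # Log of the mean Gaussian potential exp(-2 * ||x_i - x_j||^2) over all pairs in the batch
    x = F.normalize(x, dim=-1)
    return torch.pdist(x, p=2).pow(2).mul(-2).exp().mean().log()

def calculate_loss(self, user, item):
    user_e, item_e = self.encoder(user, item)  # [bsz, dim]
    align = self.alignment(user_e, item_e)
    # Uniformity is measured within users and within items, then averaged
    uniform = (self.uniformity(user_e) + self.uniformity(item_e)) / 2
    loss = align + self.gamma * uniform
    return loss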
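The methods above assume a surrounding model class that provides self.encoder and self.gamma. The standalone sketch below wires the same alignment and uniformity functions to a hypothetical two-table MF encoder (ToyMFEncoder) and runs a few Adam steps on random (user, item) pairs, so the loss can be tried without RecBole. The encoder, the random data, and the hyper-parameter values are illustrative assumptions, not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def alignment(x, y):
    # Mean squared distance between normalized embeddings of positive pairs
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return (x - y).norm(p=2, dim=1).pow(2).mean()


def uniformity(x):
    # Log of the mean Gaussian potential over all pairs in the batch
    x = F.normalize(x, dim=-1)
    return torch.pdist(x, p=2).pow(2).mul(-2).exp().mean().log()


class ToyMFEncoder(nn.Module):
    """Hypothetical MF encoder: one embedding table for users, one for items."""

    def __init__(self, n_users, n_items, dim=64):
        super().__init__()
        self.user_embedding = nn.Embedding(n_users, dim)
        self.item_embedding = nn.Embedding(n_items, dim)

    def forward(self, user, item):
        return self.user_embedding(user), self.item_embedding(item)


def directau_loss(encoder, user, item, gamma=1.0):
    user_e, item_e = encoder(user, item)  # [bsz, dim]
    align = alignment(user_e, item_e)
    uniform = (uniformity(user_e) + uniformity(item_e)) / 2
    return align + gamma * uniform


if __name__ == "__main__":
    torch.manual_seed(0)
    encoder = ToyMFEncoder(n_users=100, n_items=200, dim=16)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3, weight_decay=1e-6)
    # Random positive (user, item) pairs standing in for one training batch
    user = torch.randint(0, 100, (256,))
    item = torch.randint(0, 200, (256,))
    for step in range(5):
        loss = directau_loss(encoder, user, item, gamma=0.5)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(step, loss.item())
```

Note that no negative sampling appears anywhere: the uniformity term plays the role that sampled negatives play in losses such as BPR, by pushing all representations in the batch apart.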

We integrate our DirectAU method (directau.py) into the RecBole framework. The datasets used in the paper are already included in the dataset folder. Related experimental settings can be found in the properties folder. To reproduce the results, you can run the following commands after installing all the requirements:

# Beauty
python run_recbole.py \
    --model=DirectAU --dataset=Beauty \
    --learning_rate=1e-3 --weight_decay=1e-6 \
    --gamma=0.5 --encoder=MF --train_batch_size=256

# Gowalla
python run_recbole.py \
    --model=DirectAU --dataset=Gowalla \
    --learning_rate=1e-3 --weight_decay=1e-6 \
    --gamma=5 --encoder=MF --train_batch_size=1024

# Yelp2018
python run_recbole.py \
    --model=DirectAU --dataset=Yelp \
    --learning_rate=1e-3 --weight_decay=1e-6 \
    --gamma=1 --encoder=MF --train_batch_size=1024

To test DirectAU on other datasets, prepare them in the same format as the existing ones, i.e. as RecBole atomic files. See Atomic Files in the RecBole documentation for details of the format.
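As a rough illustration of the atomic-file format, the sketch below writes interactions to a tab-separated .inter file whose header uses RecBole's name:type convention (user_id, item_id, and timestamp are RecBole's default field names). The function name write_inter_file and the dataset name MyData are hypothetical, and whether the bundled datasets also carry rating or timestamp columns is not stated here, so check their headers and mirror them.

```python
import csv
import os


def write_inter_file(data_dir, dataset, interactions):
    """Write (user_id, item_id, timestamp) rows as a RecBole-style .inter file.

    The file is placed at <data_dir>/<dataset>/<dataset>.inter, the layout
    RecBole looks for when --dataset=<dataset> is given.
    """
    folder = os.path.join(data_dir, dataset)
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, f"{dataset}.inter")
    with open(path, "w", newline="") as f:
        # Tab-separated, Unix line endings, header fields as name:type
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(["user_id:token", "item_id:token", "timestamp:float"])
        for user_id, item_id, ts in interactions:
            writer.writerow([user_id, item_id, ts])
    return path
```

For example, write_inter_file("dataset", "MyData", rows) would create dataset/MyData/MyData.inter, after which --dataset=MyData could be passed to run_recbole.py.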

The main hyper-parameters of DirectAU include:

Param               Default   Description
--embedding_size    64        The embedding size.
--gamma             1         The weight of the uniformity loss.
--encoder           MF        The encoder type: MF or LightGCN.
--n_layers          None      The number of LightGCN layers (used only when --encoder=LightGCN).
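With --encoder=LightGCN, user and item embeddings are propagated over the interaction graph for --n_layers layers before the loss is applied. The sketch below follows the propagation rule from the original LightGCN paper (symmetric normalization D^-1/2 A D^-1/2, no feature transforms or nonlinearities, mean over the layer outputs). It is a standalone illustration with hypothetical function names, not the encoder in directau.py, which may differ in details; it also assumes each (user, item) interaction appears once.

```python
import torch


def normalized_adjacency(user_ids, item_ids, n_users, n_items):
    """Build D^{-1/2} A D^{-1/2} for the bipartite user-item graph (sparse)."""
    n = n_users + n_items
    # Items are indexed after users; each interaction adds edges in both directions
    rows = torch.cat([user_ids, item_ids + n_users])
    cols = torch.cat([item_ids + n_users, user_ids])
    deg = torch.bincount(rows, minlength=n).float()
    inv_sqrt = deg.pow(-0.5)
    inv_sqrt[torch.isinf(inv_sqrt)] = 0.0  # nodes without interactions
    values = inv_sqrt[rows] * inv_sqrt[cols]
    return torch.sparse_coo_tensor(torch.stack([rows, cols]), values, (n, n)).coalesce()


def lightgcn_propagate(user_emb, item_emb, norm_adj, n_layers):
    """Average the embeddings of layers 0..n_layers of LightGCN propagation."""
    all_emb = torch.cat([user_emb, item_emb], dim=0)
    layers = [all_emb]
    for _ in range(n_layers):
        all_emb = torch.sparse.mm(norm_adj, all_emb)
        layers.append(all_emb)
    final = torch.stack(layers, dim=0).mean(dim=0)
    n_users = user_emb.size(0)
    return final[:n_users], final[n_users:]
```

The propagated user and item embeddings would then be fed to the same alignment and uniformity terms used with the MF encoder.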

You can use the following command to tune the hyper-parameters of DirectAU (see Parameter Tuning in the RecBole documentation for details):

python run_hyper.py \
    --model=DirectAU --dataset=Beauty \
    --config_files='recbole/properties/overall.yaml recbole/properties/model/DirectAU.yaml recbole/properties/dataset/sample.yaml' \
    --params_file=directau.hyper \
    --output_file=hyper.result

Measuring Alignment and Uniformity

The measurement of alignment and uniformity given the learned representations can be implemented as follows (Appendix A.2 in the paper):

import torch
import torch.nn.functional as F

def overall_align(user_index, item_index, user_emb, item_emb, alpha=2):
    """ Args:
    user_index (torch.LongTensor): user ids of positive interactions, shape: [|R|, ]
    item_index (torch.LongTensor): item ids of positive interactions, shape: [|R|, ]
    user_emb (torch.Tensor): embeddings of all the users (e.g. nn.Embedding.weight), shape: [|U|, dim]
    item_emb (torch.Tensor): embeddings of all the items, shape: [|I|, dim]
    alpha (float): exponent on the distance; alpha=2 matches the alignment loss
    """
    x = F.normalize(user_emb[user_index], dim=-1)
    y = F.normalize(item_emb[item_index], dim=-1)
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()

def overall_uniform(index_list, embedding):
    """ Args:
    index_list (torch.LongTensor): user/item ids of positive interactions, shape: [|R|, ]
    embedding (torch.Tensor): user/item embeddings, shape: [|U|, dim] or [|I|, dim]
    """
    # Distinct ids (sorted) and their frequencies c_i among the interactions
    ids, counts = torch.unique(index_list, return_counts=True)

    # Weight of each id pair (i < j) is c_i * c_j, in the same order as torch.pdist
    rows, cols = torch.triu_indices(len(ids), len(ids), offset=1, device=counts.device)
    weight = (counts[rows] * counts[cols]).to(embedding.device)
    # Number of interaction pairs with distinct ids: (|R|^2 - sum_i c_i^2) / 2
    total_freq = (len(index_list) ** 2 - counts.pow(2).sum()) / 2

    # Normalize onto the unit hypersphere, as in overall_align
    x = F.normalize(embedding[ids.to(embedding.device)], dim=-1)
    return torch.pdist(x, p=2).pow(2).mul(-2).exp().mul(weight).sum().div(total_freq).log()
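The frequency weighting in overall_uniform is a shortcut: rather than comparing every pair of interactions, it compares each pair of distinct ids once and weights it by c_i * c_j, the number of interaction pairs that id pair stands for. The sketch below checks this equivalence on random data by also computing the quantity the slow way, over all interaction pairs whose ids differ. Both functions are illustrative re-implementations with hypothetical names, assuming embeddings are L2-normalized before measuring.

```python
import torch
import torch.nn.functional as F


def uniformity_bruteforce(index_list, embedding):
    """log of mean exp(-2 * d^2) over all interaction pairs (a < b) with different ids."""
    x = F.normalize(embedding[index_list], dim=-1)  # one row per interaction
    sq_dist = (x.unsqueeze(1) - x.unsqueeze(0)).pow(2).sum(-1)
    n = len(index_list)
    upper = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    distinct = index_list.unsqueeze(0) != index_list.unsqueeze(1)
    return sq_dist[upper & distinct].mul(-2).exp().mean().log()


def uniformity_weighted(index_list, embedding):
    """Same quantity, computed once per pair of distinct ids and weighted by c_i * c_j."""
    ids, counts = torch.unique(index_list, return_counts=True)
    rows, cols = torch.triu_indices(len(ids), len(ids), offset=1)
    weight = (counts[rows] * counts[cols]).float()
    x = F.normalize(embedding[ids], dim=-1)
    kernel = torch.pdist(x, p=2).pow(2).mul(-2).exp()
    return (kernel * weight).sum().div(weight.sum()).log()
```

The weighted form needs one distance per pair of distinct ids instead of one per pair of interactions, which matters when popular items appear in many interactions.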


Citation

If you find this work helpful for your research, please consider citing our paper:

@inproceedings{wang2022towards,
  title={Towards Representation Alignment and Uniformity in Collaborative Filtering},
  author={Wang, Chenyang and Yu, Yuanqing and Ma, Weizhi and Zhang, Min and Chen, Chong and Liu, Yiqun and Ma, Shaoping},
  booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
  pages={1816--1825},
  year={2022}
}

Contact

Chenyang Wang ([email protected])
