
stylemotion's Introduction

StyleMotion

Our code is based on Glow and MoGlow.

Prerequisites

The conda environment defined in 'environment.yml' contains the required dependencies.

Data & Pretrained Model

Our training data is available here.

Our pretrained model is available here.

Create the environment with:

    conda env create -f environment.yml

Download the data and the pretrained model, and extract them into the project directory.

For training:

    python train.py hparams/locomotion.json locomotion

For style transfer, , and then run:

    python train.py hparams/locomotion_test.json locomotion
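
The training and style-transfer invocations differ only in the hparams file. If you want to script them, here is a minimal Python sketch. `train_command` and `run` are hypothetical helpers and are not part of the repository. The sketch assumes you run it from the project root with the conda environment active. It does not perform whatever preparation the style-transfer step needs beforehand.

```python
import subprocess
import sys
from typing import List


def train_command(hparams_path: str, dataset: str) -> List[str]:
    """Build the README invocation: python train.py <hparams> <dataset>."""
    # sys.executable is the interpreter of the currently active environment.
    return [sys.executable, "train.py", hparams_path, dataset]


def run(hparams_path: str, dataset: str, dry_run: bool = True) -> List[str]:
    """Print the command, and execute it only when dry_run is False."""
    cmd = train_command(hparams_path, dataset)
    print(" ".join(cmd))
    if not dry_run:
        # check=True raises CalledProcessError if train.py exits with an error.
        subprocess.run(cmd, check=True)
    return cmd


if __name__ == "__main__":
    run("hparams/locomotion.json", "locomotion")       # training
    run("hparams/locomotion_test.json", "locomotion")  # style transfer
```

By default, `run` does a dry run that only prints the command. Pass `dry_run=False` to actually launch `train.py`.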

License

Please see the included LICENSE file for licenses and citations.

stylemotion's People

Contributors

wenyh1616

stylemotion's Issues

Broken links

Hello, I'm interested in your research 👍
I think the download links for the pretrained weights and the dataset are broken.
Could you provide them again?
Thank you very much.

'Glow' object has no attribute 'init_lstm_hidden'

Great job! However, I get the error "'Glow' object has no attribute 'init_lstm_hidden'" when running the code. Is part of the code missing? I hope to hear from you, thanks!

    Traceback (most recent call last):
      File "train.py", line 51, in <module>
        trainer.train()
      File "StyleMotion/glow/trainer.py", line 389, in train
        self.graph.init_lstm_hidden()
      File "anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 594, in __getattr__
        type(self).__name__, name))
    AttributeError: 'Glow' object has no attribute 'init_lstm_hidden'

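Style feature visualization

@wenyh1616 Hello, could you share the code for the style feature visualization? I tried visualizing the style features with the model you provided, but the results differ considerably from those in the paper. How can I obtain the results shown in the paper's figure below? Thanks!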
Here is how I processed the style features:

    import os
    from collections import defaultdict

    import numpy as np
    import torch

    def latent_code_visualize(self, iterations):
        """Collect style codes over the training set and plot them with get_all_plots."""

        def to_float(item):
            # Tensors -> NumPy arrays; single-element arrays -> Python floats.
            if isinstance(item, torch.Tensor):
                item = item.detach().cpu().numpy()
            if isinstance(item, np.ndarray) and item.size == 1:
                item = float(item.reshape(-1)[0])
            return item

        with torch.no_grad():
            # Latent codes. TODO: move into a separate function and merge with
            # plot_clusters (t-SNE).
            vis_dicts = {}
            for phase, co_loader in [("train", self.data_loader)]:
                # Each key maps to a list with one entry per batch.
                vis_dict = defaultdict(list)
                for tcl_data in co_loader:
                    x = tcl_data["x"].to(self.data_device)        # e.g. (100, 63, 70)
                    cond = tcl_data["cond"].to(self.data_device)  # e.g. (100, 663, 70)
                    vis_dict["style_code"].append(self.graph.generate_z(x, cond))
                    vis_dict["meta"].append(tcl_data["meta"])
                    vis_dict["label"].append(tcl_data["label"])

                for key, batches in vis_dict.items():
                    if key in ("meta", "label"):
                        # Flatten the per-batch lists into one list of per-sample values.
                        vis_dict[key] = [to_float(item) for batch in batches for item in batch]
                    else:
                        # Concatenate batches, then flatten to one row per sample: (N, D).
                        codes = torch.cat(batches, 0).cpu().numpy()
                        vis_dict[key] = codes.reshape(codes.shape[0], -1)

                vis_dicts[phase] = dict(vis_dict)

            writers = {"train": self.writer}
            from visualization.latent_plot_utils import get_all_plots
            get_all_plots(vis_dicts, os.path.join(self.log_dir, "%08d" % (iterations + 1)),
                          writers, iterations + 1)
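
The comment in this method mentions t-SNE, but the plotting helper `visualization.latent_plot_utils.get_all_plots` is not shown. The exact projection behind the paper's figure is therefore unknown. t-SNE output also depends on its perplexity and random seed, which can make plots differ between runs.

As a deterministic baseline for checking whether style codes separate by label at all, a PCA projection is a simple, swapped-in alternative. It is not the authors' method. The sketch below makes two assumptions:
* `codes` is the `(N, D)` array stored under `vis_dicts["train"]["style_code"]`.
* `labels` is the matching `vis_dicts["train"]["label"]` list.

`pca_2d` and `label_centroids` are hypothetical helpers.

```python
import numpy as np


def pca_2d(codes: np.ndarray) -> np.ndarray:
    """Project (N, D) style codes onto their first two principal components.

    Needs N >= 2 and D >= 2; otherwise fewer than two columns are returned.
    """
    centered = codes - codes.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def label_centroids(points: np.ndarray, labels) -> dict:
    """Mean projected position of each style label."""
    labels = np.asarray(labels)
    return {lab.item(): points[labels == lab].mean(axis=0) for lab in np.unique(labels)}


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Two synthetic "styles" offset from each other, standing in for real style codes.
    codes = np.vstack([rng.normal(0.0, 0.1, (20, 8)), rng.normal(5.0, 0.1, (20, 8))])
    labels = [0] * 20 + [1] * 20
    points = pca_2d(codes)
    centroids = label_centroids(points, labels)
    print(points.shape, np.linalg.norm(centroids[0] - centroids[1]))
```

PCA is linear, so clusters that overlap here may still separate under t-SNE. Treat this as a sanity check rather than a reproduction of the paper's figure.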
