
Comments (9)

tianleiwu commented on July 18, 2024

For an ONNX model larger than 2 GB, it needs to be saved like the following (notice that the last parameter, save_as_external_data, is True):

onnx.save(onnx_model, "path/to/model.onnx", save_as_external_data=True)
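
If you also want the external tensors gathered into a single side-car file instead of many small ones, onnx.save takes a few extra parameters, something like this (paths are placeholders):

import onnx

onnx_model = onnx.load("path/to/model.onnx")
onnx.save(
    onnx_model,
    "path/to/model_ext.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,    # one external data file instead of one file per tensor
    location="model_ext.onnx_data",  # written next to the saved .onnx file
    size_threshold=1024,             # only tensors larger than 1 KB go external
)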

The negative prompt is used only when the CFG scale is > 1.0. See this example:

if do_classifier_free_guidance:
    # For SD XL base, handle force_zeros_for_empty_prompt
    is_empty_negative_prompt = all([not i for i in negative_prompt])
    if force_zeros_for_empty_prompt and is_empty_negative_prompt:
        uncond_embeddings = torch.zeros_like(text_embeddings)
        if output_hidden_states:
            uncond_hidden_states = torch.zeros_like(hidden_states)
    else:
        # Tokenize negative prompt
        uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)
    # Concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes for classifier-free guidance
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

The negative prompt will increase the "batch size" of the text embedding (for example, one batch for the positive prompt and another for the negative prompt). You can check whether the text embeddings for the positive and negative prompts are the same or not. It is likely a bug in your code if you see the same positive/negative embedding or latent.
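
A quick way to do that check, something like:

import torch

# After torch.cat the embedding has shape [2 * batch, seq_len, dim]:
uncond, cond = text_embeddings.chunk(2)
if torch.allclose(uncond, cond):
    print("BUG: negative and positive embeddings are identical")
print("max |cond - uncond| =", (cond - uncond).abs().max().item())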


Craigacp commented on July 18, 2024

You can see a roughly equivalent example in Java here - https://github.com/oracle/sd4j/blob/main/src/main/java/com/oracle/labs/mlrg/sd4j/UNet.java#L268; it should map to C++ pretty straightforwardly. You need to duplicate the images to use guidance, as @tianleiwu said.
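
In diffusers-style Python, the duplication and the guidance merge look roughly like this (a sketch; the names are illustrative, not the sd4j code itself):

import torch

def guided_unet_step(unet, scheduler, latents, text_embeddings, t, guidance_scale):
    # Duplicate the latents so the unconditional and conditional halves
    # go through the UNet in one batched forward pass.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

    noise_pred = unet(latent_model_input, t,
                      encoder_hidden_states=text_embeddings).sample

    # Split the batch back into its two halves and apply classifier-free guidance.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # The scheduler step produces the latents for the next iteration.
    return scheduler.step(noise_pred, t, latents).prev_sample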


tianleiwu commented on July 18, 2024

@Windsander, it is good that you found the cause.

In general, you can run your code and print every intermediate result, then run a baseline (like the diffusers pipeline) using the same inputs. By comparing the intermediate results, you can easily find out which step is causing the parity issue.
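
For example, if you dump the same tensors from both pipelines to .npy files, a small helper like this makes the comparison easy (the file names are just examples):

import numpy as np

def compare(name, mine_path, baseline_path, atol=1e-3):
    mine = np.load(mine_path)
    base = np.load(baseline_path)
    diff = np.abs(mine - base).max()
    print(f"{name}: max abs diff = {diff:.6f}", "OK" if diff < atol else "MISMATCH")

# Dump the same tensors (text embeddings, scaled latents, each UNet output, ...)
# from both pipelines, then:
compare("text_embeddings", "cpp/text_embeddings.npy", "diffusers/text_embeddings.npy")
compare("unet_step_0",     "cpp/unet_step_0.npy",     "diffusers/unet_step_0.npy")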


Windsander commented on July 18, 2024

step-1 initial latent (scaled by the scheduler):

[image: output-s1-scaled]

cond_latent after UNet inference:

[image: output-s1-p]

uncond_latent after UNet inference:

[image: output-s1-n]


Windsander commented on July 18, 2024

And most importantly, I can't just merge the two files into one by using the script below:

import onnx
from onnx.external_data_helper import load_external_data_for_model

onnx_model = onnx.load("/Volumes/AL-Data-W04/WorkingSpace/Self-Storage-Local/ML_study_demo/model/convert/model.onnx", load_external_data=False)
load_external_data_for_model(onnx_model, "/Volumes/AL-Data-W04/WorkingSpace/Self-Storage-Local/ML_study_demo/model/convert/")
# Then the onnx_model has loaded the external data from the specific directory

onnx.save(onnx_model, "/Volumes/AL-Data-W04/WorkingSpace/Self-Storage-Local/ML_study_demo/model/converted/model.onnx", save_as_external_data=False)

because of:

    raise ValueError(
ValueError: The proto size is larger than the 2 GB limit. Please use save_as_external_data to save tensors separately from the model file.

which is what the official documentation says:
[image: screenshot of the official ONNX documentation]
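
So the best I can apparently do is something like this sketch: only attempt a single-file save when the proto actually fits under the limit, and otherwise keep the weights external (as in the save call discussed above):

import onnx

onnx_model = onnx.load("model/convert/model.onnx")  # external data is resolved automatically

# Protobuf serialization is hard-capped at 2 GB, so a true single-file save
# only works when the whole proto fits under that limit:
if onnx_model.ByteSize() < 2 * 1024 ** 3:
    onnx.save(onnx_model, "model/converted/model.onnx", save_as_external_data=False)
else:
    print("over the 2 GB protobuf limit; keep save_as_external_data=True")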


Windsander commented on July 18, 2024

Thanks! I'll try it shortly! :)


Windsander commented on July 18, 2024

I have changed txt_embeddings_ to tidx_encoded + uncond_encoded, but the result is still not quite correct (although it seems to be somewhat better than before). :(

The code for the inference part of my UNet is as follows:

Tensor UNet::inference(
    const Tensor &txt_embeddings_,
    const Tensor &encoded_img_
) {
    int w_ = int(sd_unet_config.sd_input_width);
    int h_ = int(sd_unet_config.sd_input_height);
    int c_ = int(sd_unet_config.sd_input_channel);
    const bool need_guidance_ = (sd_unet_config.sd_scale_guidance > 1);

    TensorShape latent_shape_{1, c_, h_, w_};
    std::vector<float> latent_empty_(c_ * h_ * w_, 0.0f);
    Tensor latents_ = (TensorHelper::have_data(encoded_img_)) ?
                      TensorHelper::clone<float>(encoded_img_, latent_shape_) :
                      TensorHelper::create(latent_shape_, latent_empty_);
    Tensor init_mask_ = sd_scheduler_p->mask(latent_shape_);
    latents_ = TensorHelper::add(latents_, init_mask_, latent_shape_);

    for (int i = 0; i < sd_unet_config.sd_inference_steps; ++i) {
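        // When guidance is enabled, duplicate the latent batch so the conditional and
        // unconditional predictions come out of a single UNet forward pass.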
        Tensor model_latent_ = (need_guidance_) ?
                               sd_scheduler_p->scale(TensorHelper::duplicate<float>(latents_), i) :
                               sd_scheduler_p->scale(latents_, i);
        Tensor timestep_ = sd_scheduler_p->time(i);
        //TensorHelper::print_tensor_data<int64_t>(timestep_,  "timestep_" + std::to_string(i));
        //TensorHelper::print_tensor_data<float>(model_latent_,  "model_latent_" + std::to_string(i));

        // do positive N_pos_embed_num times
        std::vector<Tensor> input_tensors;
        input_tensors.emplace_back(TensorHelper::clone<float>(model_latent_));
        input_tensors.emplace_back(TensorHelper::clone<int64_t>(timestep_));
        input_tensors.emplace_back(TensorHelper::clone<float>(txt_embeddings_));
        std::vector<Tensor> output_tensors;
        generate_output(output_tensors);
        execute(input_tensors, output_tensors);

        // Split results
        std::vector<Tensor> output_splits = TensorHelper::split(output_tensors[0], latent_shape_);
        Tensor pred_normal_ = std::move(output_splits[0]);
        Tensor pred_uncond_ = std::move(output_splits[1]);

        // Merge predictions
        float merge_factor_ = sd_unet_config.sd_scale_guidance;
        Tensor guided_pred_ = (
            need_guidance_ ?
            TensorHelper::guidance(pred_normal_, pred_uncond_, merge_factor_) :
            TensorHelper::clone<float>(pred_normal_, latent_shape_)
        );

        // Denoise & Step
        latents_ = sd_scheduler_p->step(latents_, guided_pred_, i);

        CommonHelper::print_progress_bar(float(i + 1) / float(sd_unet_config.sd_inference_steps));
    }

    return latents_;
}

The printed image comes from the data of pred_normal_ in the code.

[image: output-s1-n]

Seems I still need a favor, guys. 0x0


Windsander commented on July 18, 2024

I found that the reason the UNet generated such results was that I did not properly tokenize using the parsed vocabulary in the Tokenizer. I'm not sure if my judgment is accurate. emm
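
To check this, I can compare my ids against a reference tokenizer, something like the sketch below (CLIP's tokenizer is BPE-based rather than WordPiece, so a hand-rolled WordPiece split can easily drift from the reference ids):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a photo of an astronaut riding a horse"

# Reference ids to compare against the ids produced by the C++ tokenizer.
reference_ids = tokenizer(prompt, padding="max_length", max_length=77).input_ids
print(reference_ids[:12])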


Windsander commented on July 18, 2024

@tianleiwu Finally! I found out what made this UNet output so messy!

The reasons for generating such UNet output are twofold:

  1. During token processing, segmenting the indices with a self-implemented WordPiece tokenizer produced dictionary index deviations, so the token_idx input to CLIP was imprecise.
  2. The packaging of the timestep Tensor went through a long(999) -> float(1.39989717E-42, a failed bit reinterpretation) -> long(0) conversion, which killed the time_embedding neurons once it reached the UNet (see the sketch below).
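
That second failure is easy to reproduce in a few lines (a sketch of the bit reinterpretation, not my actual C++ code):

import struct

t = 999                                           # timestep as a 64-bit integer
# Reinterpreting the raw bytes as float32 instead of converting the value:
wrong = struct.unpack('<f', struct.pack('<q', t)[:4])[0]
print(wrong)        # ~1.39989717e-42, a denormal float, not 999.0
print(int(wrong))   # 0 -> every step reaches the UNet as timestep 0
# The fix is a value conversion, not a bit cast:
right = float(t)    # 999.0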

To identify this issue, I wrote a simplified Python script that is completely consistent with the C++ implementation/encapsulation. With the same inputs, I found that the step-1 result in Python was correct.
After step-by-step validation and single-layer debugging of the ONNX models, I finally found that neither the CLIP, UNet, and VAE models themselves nor the scheduler or tokenizer had any issues.
The cause only became apparent when I started one-to-one validation of the Tensors passed between models.
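
For anyone doing the same kind of per-model validation, a minimal onnxruntime check looks something like this (the model path and input name are just examples and depend on how the model was exported):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("clip_text_encoder.onnx")
# Placeholder token ids; use the ids produced by the reference tokenizer.
input_ids = np.zeros((1, 77), dtype=np.int64)
outputs = sess.run(None, {"input_ids": input_ids})
print([o.shape for o in outputs])
# Then compare these outputs against the tensors the C++ pipeline passes downstream.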

Therefore, I am documenting the cause of the issue here, hoping that other colleagues can avoid making the same foolish mistake as I did. (0x0 )(0x0 )

And now, the step-1 prediction of the noised_latent is correct!
[image: output-s1-n]

Great thanks for your inspiration! @Craigacp @tianleiwu 🌹

