
Comments (6)

WSJUSA commented on June 13, 2024

init_img, which serves as the color reference for colorfix, is now being passed to colorfix with 4 channels. This appears to be caused by a change in how p (the StableDiffusionProcessingImg2Img object) prepares the source image.

A possible fix is to check every tensor passed into colorfix and reduce it to 3 channels.

Keep in mind that I have no idea what I am really doing, but this seems to work. I added the helper below to the top of colorfix.py and made both color-fix functions use it:

def channel_four_to_three(image: Tensor):
    # If the tensor has more than 3 channels (e.g. RGBA), keep only the first 3 (RGB)
    if image.shape[1] > 3:
        image = image[:, :3, :, :]
    return image

def adain_color_fix(target: Image, source: Image):
    # Convert images to tensors, dropping any alpha channel
    to_tensor = ToTensor()
    target_tensor = channel_four_to_three(to_tensor(target).unsqueeze(0))
    source_tensor = channel_four_to_three(to_tensor(source).unsqueeze(0))

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def wavelet_color_fix(target: Image, source: Image):
    # Convert images to tensors, dropping any alpha channel
    to_tensor = ToTensor()
    target_tensor = channel_four_to_three(to_tensor(target).unsqueeze(0))
    source_tensor = channel_four_to_three(to_tensor(source).unsqueeze(0))

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image
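To see why the extra channel matters: ToTensor() turns an RGBA PIL image into a 4-channel tensor, while wavelet_blur in colorfix.py builds its kernel with repeat(3, 1, 1, 1) and calls conv2d with groups=3, so it only accepts exactly 3 channels (adaptive_instance_normalization would likewise fail when target and source differ in channel count, because expand cannot turn 4 channels into 3). The sketch below reproduces just that grouped convolution on a random tensor. It is a standalone stand-in, not the extension's code, and depthwise_blur_3ch is a hypothetical name.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def channel_four_to_three(image: Tensor) -> Tensor:
    # Same idea as the patch above: keep only the first three (RGB) channels
    # of an N x C x H x W tensor, discarding alpha if present.
    if image.shape[1] > 3:
        image = image[:, :3, :, :]
    return image


def depthwise_blur_3ch(image: Tensor) -> Tensor:
    # Hypothetical stand-in for the grouped convolution in wavelet_blur:
    # the 3x3 kernel is repeated exactly 3 times and groups=3, so the
    # input must have exactly 3 channels.
    kernel = torch.tensor(
        [[0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]],
        dtype=image.dtype,
    )[None, None].repeat(3, 1, 1, 1)
    image = F.pad(image, (1, 1, 1, 1), mode="replicate")
    return F.conv2d(image, kernel, groups=3)


# Shape of ToTensor() output for an RGBA image, plus a batch dimension
rgba = torch.rand(1, 4, 8, 8)

try:
    depthwise_blur_3ch(rgba)
    failed = False
except RuntimeError as err:  # conv2d rejects 4 input channels for groups=3
    failed = True
    print("4-channel input fails:", err)

rgb = channel_four_to_three(rgba)
print(rgb.shape)
print(depthwise_blur_3ch(rgb).shape)
```

Slicing keeps the RGB values untouched and simply discards alpha, which is why the patch does not change the colors it works with.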

from sd-webui-stablesr.

FusionDraw9257 commented on June 13, 2024

I rolled back to version 1.5.1:
1. In the SD folder, open CMD and run: git checkout tags/v1.5.1
2. Delete the venv folder inside the SD folder.
3. Restart.


WSJUSA commented on June 13, 2024

I can confirm that Automatic1111 v1.5.1 resolves the color issue.

I did this with StabilityMatrix, so I did not have to roll back my 1.6 install or flip between git checkouts: you can install both versions side by side under Packages.

One small annoyance: you have to copy the StableSR model webui_768v_139.ckpt into the extension in both copies of the SR plugin.


snoopytl commented on June 13, 2024

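Using the code from WSJUSA's comment above, I still got errors, so I modified it further. After that it works fine on both 1.6 and 1.7. Here is the code; you can use it to replace the entire contents of colorfix.py: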
```python
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F

from torchvision.transforms import ToTensor, ToPILImage

def channel_four_to_three(image: Tensor):
    # If the tensor has more than 3 channels (e.g. RGBA), keep only the first 3 (RGB)
    if image.shape[1] > 3:
        image = image[:, :3, :, :]
    return image

def adain_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = channel_four_to_three(to_tensor(target).unsqueeze(0))
    source_tensor = channel_four_to_three(to_tensor(source).unsqueeze(0))

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def wavelet_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = channel_four_to_three(to_tensor(target).unsqueeze(0))
    source_tensor = channel_four_to_three(to_tensor(source).unsqueeze(0))
    # Added by me: the rest of this function
    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

# End of added code

# Original versions (without channel_four_to_three), kept commented out:
# def adain_color_fix(target: Image, source: Image):
#     # Convert images to tensors
#     to_tensor = ToTensor()
#     target_tensor = to_tensor(target).unsqueeze(0)
#     source_tensor = to_tensor(source).unsqueeze(0)
#     # Apply adaptive instance normalization
#     result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
#     # Convert tensor back to image
#     to_image = ToPILImage()
#     result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
#     return result_image
#
# def wavelet_color_fix(target: Image, source: Image):
#     # Convert images to tensors
#     to_tensor = ToTensor()
#     target_tensor = to_tensor(target).unsqueeze(0)
#     source_tensor = to_tensor(source).unsqueeze(0)
#     # Apply wavelet reconstruction
#     result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
#     # Convert tensor back to image
#     to_image = ToPILImage()
#     result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
#     return result_image

def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    # reshape (not view) also works on non-contiguous inputs, e.g. after channel slicing
    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.reshape(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.
    Adjust the reference features to have similar color and illumination
    as those in the degraded features.
    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output

def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the high frequency & the low frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq

def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct: the content's high frequency (detail) plus the style's low frequency (color)
    return content_high_freq + style_low_freq
```
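A quick way to sanity-check the two color-fix methods outside the WebUI is to run them on random tensors: AdaIN should give each channel of the result the mean and standard deviation of the style (here the init image), and the wavelet bands should sum back to the input, because the high-frequency band telescopes to image minus the final low band. This is a standalone sketch that re-implements the functions above with two assumed liberties: broadcasting and reshape replace expand and view, and wavelet_blur reads the channel count from the input instead of hard-coding 3. That last change is a hypothetical alternative, not what was posted here; it only removes the grouped-convolution error, and AdaIN still needs target and source to have the same number of channels.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

BLUR_KERNEL = [[0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]]


def calc_mean_std(feat: Tensor, eps: float = 1e-5):
    # Per-sample, per-channel mean and std over the spatial dimensions
    b, c = feat.shape[:2]
    flat = feat.reshape(b, c, -1)
    mean = flat.mean(dim=2).view(b, c, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    return mean, std


def adaptive_instance_normalization(content: Tensor, style: Tensor) -> Tensor:
    # Whiten the content per channel, then re-color it with the style statistics
    s_mean, s_std = calc_mean_std(style)
    c_mean, c_std = calc_mean_std(content)
    return (content - c_mean) / c_std * s_std + s_mean


def wavelet_blur(image: Tensor, radius: int) -> Tensor:
    # Dilated 3x3 blur; channel count taken from the input (hypothetical tweak)
    c = image.shape[1]
    kernel = torch.tensor(BLUR_KERNEL, dtype=image.dtype, device=image.device)
    kernel = kernel[None, None].repeat(c, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode="replicate")
    return F.conv2d(image, kernel, groups=c, dilation=radius)


def wavelet_decomposition(image: Tensor, levels: int = 5):
    # Accumulate the detail removed at each blur level; return (high, low)
    high = torch.zeros_like(image)
    for i in range(levels):
        low = wavelet_blur(image, 2 ** i)
        high += image - low
        image = low
    return high, low


def wavelet_reconstruction(content: Tensor, style: Tensor) -> Tensor:
    # Content detail (high band) plus style color (low band)
    content_high, _ = wavelet_decomposition(content)
    _, style_low = wavelet_decomposition(style)
    return content_high + style_low


torch.manual_seed(0)
content = torch.rand(1, 3, 32, 32)
style = torch.rand(1, 3, 32, 32) * 0.5 + 0.3

out = adaptive_instance_normalization(content, style)
print(out.mean(dim=(2, 3)), style.mean(dim=(2, 3)))

high, low = wavelet_decomposition(content)
print(torch.allclose(high + low, content, atol=1e-5))  # True: the bands sum back to the input
```

The 32x32 size keeps the largest padding radius (16, at the fifth level) smaller than the image; larger real images are unaffected by this choice.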


skywalker0113 commented on June 13, 2024

I tested it myself and it works; the formatting of the posted code is just a mess.


yw-2020 commented on June 13, 2024

In reply to skywalker0113's "tested it myself and it works, but the formatting is a mess": what do you mean by messy formatting? Isn't it just a matter of replacing the modified functions?
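If you would rather not edit the bodies of adain_color_fix and wavelet_color_fix at all, one alternative is to normalize the PIL inputs once with Image.convert("RGB"), which for an RGBA image drops the alpha channel, just as the [:, :3] slice does. This is a hypothetical sketch, not code from this thread or from the extension; the decorator name ensure_rgb and the stand-in function report_modes are illustrative only.

```python
from functools import wraps

from PIL import Image


def ensure_rgb(func):
    """Hypothetical decorator: convert every PIL image argument to RGB before calling func."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        def fix(value):
            # Only touch PIL images that are not already RGB (e.g. RGBA init images)
            if isinstance(value, Image.Image) and value.mode != "RGB":
                return value.convert("RGB")
            return value

        args = [fix(a) for a in args]
        kwargs = {k: fix(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)

    return wrapper


@ensure_rgb
def report_modes(target, source):
    # Stand-in for adain_color_fix / wavelet_color_fix: report what actually arrives
    return target.mode, source.mode


rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 128))
rgb = Image.new("RGB", (4, 4), (0, 255, 0))
print(report_modes(rgb, rgba))  # ('RGB', 'RGB')
```

In colorfix.py this would sit as @ensure_rgb above the two color-fix functions, assuming they only ever receive PIL images, as their type hints suggest.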
