Commit 37a9fb00 authored by mac's avatar mac

C

parent 2b8bff84
...@@ -87,7 +87,9 @@ class Generator(nn.Module): ...@@ -87,7 +87,9 @@ class Generator(nn.Module):
nn.Tanh() nn.Tanh()
) )
def forward(self, input, align_corners=True): def forward(self, inputTensor, align_corners=True):
input = inputTensor.unsqueeze(0)*2-1
out = self.block_a(input) out = self.block_a(input)
half_size = out.size()[-2:] half_size = out.size()[-2:]
out = self.block_b(out) out = self.block_b(out)
...@@ -106,5 +108,6 @@ class Generator(nn.Module): ...@@ -106,5 +108,6 @@ class Generator(nn.Module):
out = self.block_e(out) out = self.block_e(out)
out = self.out_layer(out) out = self.out_layer(out)
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
return out return out
\ No newline at end of file
...@@ -44,9 +44,9 @@ def test(args): ...@@ -44,9 +44,9 @@ def test(args):
image = load_image(os.path.join(args.input_dir, image_name), args.x32) image = load_image(os.path.join(args.input_dir, image_name), args.x32)
with torch.no_grad(): with torch.no_grad():
image = to_tensor(image).unsqueeze(0) * 2 - 1 image = to_tensor(image) #.unsqueeze(0) * 2 - 1
out = net(image.to(device), args.upsample_align).cpu() out = net(image.to(device), args.upsample_align).cpu()
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 #out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
out = to_pil_image(out) out = to_pil_image(out)
out.save(os.path.join(args.output_dir, image_name)) out.save(os.path.join(args.output_dir, image_name))
......
...@@ -108,11 +108,12 @@ def img2Comix(args,net,device,image): ...@@ -108,11 +108,12 @@ def img2Comix(args,net,device,image):
#print('>>img_2_x32 size=', image.shape) #print('>>img_2_x32 size=', image.shape)
with torch.no_grad(): with torch.no_grad():
#print('----------torch.no_grad() handle') #print('----------torch.no_grad() handle')
#变成tensor数据后,形状变为CxHxW #变成tensor数据后,形状变为CxHxW 移到model内部
image = to_tensor(image).unsqueeze(0) * 2 - 1 #image = to_tensor(image).unsqueeze(0) * 2 - 1
#print('upsample_align=',args.upsample_align) #print('upsample_align=',args.upsample_align)
out = net(image.to(device), args.upsample_align).cpu() out = net(image.to(device), args.upsample_align).cpu()
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 #移到模型内部
#out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
# out = to_pil_image(out) # out = to_pil_image(out)
out = (out).numpy() out = (out).numpy()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment