Yeah, having a reproducible example would be helpful. Within a single rendering call we assume all the pytorch tensors live on the same GPU. Maybe there are some bugs related to that.
It turned out to be easier than I imagined to create a repro case. The scene is a single sphere, with the camera close enough to fill the entire image. There's a while loop which keeps rendering images in batches of 16 (I used 2 GPUs) until it spots one with an incomplete alpha channel. The archive also contains the image which triggered the loop to terminate.
Thanks again.
You want to call pyredner.set_device() inside your rendering module so that redner can use the right device. I modified your code like this and it works for me in 0.4.11.
import torch
import redner
import pyredner

pyredner.set_use_gpu(True)

class BatchRenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, seed, *args):
        batch_dims = args[0]
        args_old_format = args[1:]
        chunk_len = int(len(args_old_format) / batch_dims)
        h, w = args_old_format[12]
        result = torch.zeros(
            batch_dims, h, w, 6, device=pyredner.get_device(), requires_grad=True)
        # Render each batch element into its slice of the batched output.
        for k in range(0, batch_dims):
            sub_args = args_old_format[k * chunk_len:(k + 1) * chunk_len]
            result[k, :, :, :] = pyredner.RenderFunction.forward(ctx, seed, *sub_args)
        return result

    @staticmethod
    def backward(ctx, grad_img):
        # None gradients for seed and batch_dims
        ret_list = (None, None,)
        batch_dims = grad_img.shape[0]
        for k in range(0, batch_dims):
            # [1:] because the original backward returns a None grad for the seed input,
            # but we manage that ourselves
            ret_list = ret_list + pyredner.RenderFunction.backward(ctx, grad_img[k, :, :, :])[1:]
        return ret_list

class RenderModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size = x.shape[0]
        cam_to_world = pyredner.transform.gen_look_at_matrix(
            torch.tensor([0.0, 0.0, 2.0]),
            torch.tensor([0.0, 0.0, 0.0]),
            torch.tensor([0.0, 1.0, 0.0])
        )
        camera = pyredner.Camera(
            cam_to_world=cam_to_world,
            fov=torch.tensor([18.0]),
            clip_near=1e-2,
            resolution=(256, 256),
            fisheye=False,
        )
        vertices = []
        indices = None
        uvs = None
        normals = None
        for b in range(batch_size):
            _vertices, _indices, _uvs, _normals = pyredner.generate_sphere(64, 128)
            vertices.append(_vertices.unsqueeze(0))
            if indices is None: indices = _indices
            if uvs is None: uvs = _uvs
            if normals is None: normals = _normals
        vertices = torch.cat(vertices, dim=0)
        pyredner.set_device(vertices.device)
        args = [batch_size]
        materials = [pyredner.Material(diffuse_reflectance=torch.tensor([0.5, 0.5, 0.5]))]
        # Serialize one scene per batch element.
        for b in range(batch_size):
            shapes = [
                pyredner.Shape(
                    vertices=vertices[b],
                    indices=indices,
                    uvs=uvs,
                    normals=normals,
                    material_id=0
                )
            ]
            scene = pyredner.Scene(camera, shapes, materials, area_lights=[], envmap=None)
            args += pyredner.RenderFunction.serialize_scene(
                scene=scene,
                num_samples=1,
                max_bounces=0,
                channels=[
                    redner.channels.alpha,
                    redner.channels.diffuse_reflectance,
                    redner.channels.uv
                ]
            )
        return BatchRenderFunction.apply(1, *args)

model = RenderModule()
model = torch.nn.DataParallel(model)
batch_size = 16
# Keep rendering batches until an image with an incomplete alpha channel shows up.
while True:
    x = torch.randn(batch_size, 1).cuda()
    images = model(x)
    for i in range(batch_size):
        img = images[i, :, :, 1:4]
        pyredner.imwrite(img.cpu(), 'debug_redner_%02d.png' % i)
        if (images[i, :, :, 0] < 0.1).any():
            raise RuntimeError(i)
Another minor bug in the provided code is that redner outputs HWC arrays, while the code assumes CHW.
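For example (a small sketch with a stand-in tensor in place of the real render output), converting between the two layouts is just a permute:

import torch

# redner returns images in (H, W, C) layout; most PyTorch code expects (C, H, W).
images = torch.zeros(16, 256, 256, 6)   # stand-in for the batched render output above
img_chw = images[0].permute(2, 0, 1)    # (6, 256, 256), channels-first for conv layers etc.
img_hwc = img_chw.permute(1, 2, 0)      # back to (256, 256, 6), e.g. for pyredner.imwrite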
We might want to rethink how the device assignment works in redner...
Thank you, that appears to help in the repro case (and reveals some issues with how I was handling devices), but in the context of my original example there's now a race condition of some sort, potentially because pyredner.set_device/get_device manipulate global state.
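To make the suspicion concrete, here's a toy sketch of the interleaving I have in mind (plain Python with a fake set_device, not redner code): DataParallel runs one forward thread per replica, so another replica can overwrite the module-level device between my set_device call and the point where it is read.

import threading
import time
import random

# Toy illustration of the suspected race (NOT redner code): a single
# module-level "current device" shared by every DataParallel replica thread.
_current_device = None

def set_device(d):
    global _current_device
    _current_device = d

def fake_render(expected_device):
    set_device(expected_device)
    # Widen the window between "set" and "use", much like the time.sleep()
    # calls in the repro script below.
    time.sleep(random.uniform(0.01, 0.05))
    if _current_device != expected_device:
        print('race: expected %s, global is now %s' % (expected_device, _current_device))

threads = [threading.Thread(target=fake_render, args=('cuda:%d' % i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()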
Here's an updated example which, when run, will likely fire one of the asserts or terminate after thrust::system::system_error is thrown. The original RuntimeError, however, does appear to be fixed.
Note that I also updated the autograd function so that the device in question is specified explicitly within args.
I'll also update to 0.4.11.
import torch
import redner
import pyredner
import time
import numpy as np

pyredner.set_use_gpu(True)

class BatchRenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, seed, *args):
        batch_dims, device = args[0], args[1]
        args_old_format = args[2:]
        chunk_len = int(len(args_old_format) / batch_dims)
        h, w = args_old_format[11]
        result = torch.zeros(
            batch_dims, h, w, 6, device=device, requires_grad=True)
        # Render each batch element into its slice of the batched output.
        for k in range(0, batch_dims):
            sub_args = args_old_format[k * chunk_len:(k + 1) * chunk_len]
            result[k, :, :, :] = pyredner.RenderFunction.forward(ctx, seed, *sub_args)
        return result

    @staticmethod
    def backward(ctx, grad_img):
        # None gradients for seed, batch_dims and device
        ret_list = (None, None, None,)
        batch_dims = grad_img.shape[0]
        for k in range(0, batch_dims):
            # [1:] because the original backward returns a None grad for the seed input,
            # but we manage that ourselves
            ret_list = ret_list + pyredner.RenderFunction.backward(ctx, grad_img[k, :, :, :])[1:]
        return ret_list

class RenderModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter(
            "scale",
            torch.nn.Parameter(torch.tensor(1.0))
        )
        self.register_parameter(
            "offset",
            torch.nn.Parameter(torch.tensor(0.0))
        )

    def forward(self, x):
        assert x.device == self.scale.device
        assert x.device == self.offset.device
        batch_size = x.shape[0]
        cam_to_world = pyredner.transform.gen_look_at_matrix(
            torch.tensor([0.0, 0.0, 2.0]),
            torch.tensor([0.0, 0.0, 0.0]),
            torch.tensor([0.0, 1.0, 0.0])
        )
        camera = pyredner.Camera(
            cam_to_world=cam_to_world,
            fov=torch.tensor([18.0]),
            clip_near=1e-2,
            resolution=(256, 256),
            fisheye=False,
        )
        vertices = []
        indices = None
        uvs = None
        normals = None
        for b in range(batch_size):
            _vertices, _indices, _uvs, _normals = pyredner.generate_sphere(64, 128)
            vertices.append(_vertices.unsqueeze(0) * self.scale)
            if indices is None: indices = _indices
            if uvs is None: uvs = _uvs
            if normals is None: normals = _normals
        vertices = torch.cat(vertices, dim=0)
        pyredner.set_device(vertices.device)
        args = [batch_size, vertices.device]
        materials = [pyredner.Material(diffuse_reflectance=torch.tensor([0.5, 0.5, 0.5]))]
        assert pyredner.get_device() == vertices.device
        # Serialize one scene per batch element, checking the global device along the way.
        for b in range(batch_size):
            assert pyredner.get_device() == vertices.device
            shapes = [
                pyredner.Shape(
                    vertices=vertices[b],
                    indices=indices,
                    uvs=uvs,
                    normals=normals,
                    material_id=0
                )
            ]
            assert pyredner.get_device() == vertices.device
            # Random delay to widen the window for the suspected device race.
            time.sleep(np.random.uniform(low=0.01, high=0.05))
            assert pyredner.get_device() == vertices.device
            scene = pyredner.Scene(camera, shapes, materials, area_lights=[], envmap=None)
            assert pyredner.get_device() == vertices.device
            args += pyredner.RenderFunction.serialize_scene(
                scene=scene,
                num_samples=1,
                max_bounces=0,
                channels=[
                    redner.channels.alpha,
                    redner.channels.diffuse_reflectance,
                    redner.channels.uv
                ]
            )
            assert pyredner.get_device() == vertices.device
        result = BatchRenderFunction.apply(1, *args) + self.offset
        assert result.device == vertices.device
        return result

model = RenderModule()
model.cuda()
model = torch.nn.DataParallel(model)
batch_size = 16
optim = torch.optim.SGD(model.parameters(), lr=0.01)
while True:
    x = torch.randn(batch_size, 1).cuda()
    images = model(x)
    loss = (images - 1.0).mean()
    loss.backward()
    optim.step()
    for i in range(batch_size):
        pyredner.imwrite(images[i].cpu(), 'debug_redner_%02d.exr' % i)
        if (images[i, :, :, 0] < 0.1).any():
            raise RuntimeError(i)
Should be fixed in 0.4.14. I added an argument to serialize_scene so that you can pass in a torch.device. Modifying your code into the following works for me in 0.4.14.
import torch
import redner
import pyredner
import time
import numpy as np

pyredner.set_use_gpu(True)

class BatchRenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, seed, *args):
        batch_dims, device = args[0], args[1]
        args_old_format = args[2:]
        chunk_len = int(len(args_old_format) / batch_dims)
        h, w = args_old_format[12]
        result = torch.zeros(
            batch_dims, h, w, 6, device=device, requires_grad=True)
        # Render each batch element into its slice of the batched output.
        for k in range(0, batch_dims):
            sub_args = args_old_format[k * chunk_len:(k + 1) * chunk_len]
            result[k, :, :, :] = pyredner.RenderFunction.forward(ctx, seed, *sub_args)
        return result

    @staticmethod
    def backward(ctx, grad_img):
        # None gradients for seed, batch_dims and device
        ret_list = (None, None, None,)
        batch_dims = grad_img.shape[0]
        for k in range(0, batch_dims):
            # [1:] because the original backward returns a None grad for the seed input,
            # but we manage that ourselves
            ret_list = ret_list + pyredner.RenderFunction.backward(ctx, grad_img[k, :, :, :])[1:]
        return ret_list

class RenderModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter(
            "scale",
            torch.nn.Parameter(torch.tensor(1.0))
        )
        self.register_parameter(
            "offset",
            torch.nn.Parameter(torch.tensor(0.0))
        )

    def forward(self, x):
        assert x.device == self.scale.device
        assert x.device == self.offset.device
        batch_size = x.shape[0]
        cam_to_world = pyredner.transform.gen_look_at_matrix(
            torch.tensor([0.0, 0.0, 2.0]),
            torch.tensor([0.0, 0.0, 0.0]),
            torch.tensor([0.0, 1.0, 0.0])
        )
        camera = pyredner.Camera(
            cam_to_world=cam_to_world,
            fov=torch.tensor([18.0]),
            clip_near=1e-2,
            resolution=(256, 256),
            fisheye=False,
        )
        vertices = []
        indices = None
        uvs = None
        normals = None
        for b in range(batch_size):
            _vertices, _indices, _uvs, _normals = pyredner.generate_sphere(64, 128)
            vertices.append(_vertices.unsqueeze(0) * self.scale)
            if indices is None: indices = _indices
            if uvs is None: uvs = _uvs
            if normals is None: normals = _normals
        vertices = torch.cat(vertices, dim=0)
        args = [batch_size, vertices.device]
        materials = [pyredner.Material(diffuse_reflectance=torch.tensor([0.5, 0.5, 0.5]))]
        # Serialize one scene per batch element, passing the device explicitly.
        for b in range(batch_size):
            shapes = [
                pyredner.Shape(
                    vertices=vertices[b],
                    indices=indices,
                    uvs=uvs,
                    normals=normals,
                    material_id=0
                )
            ]
            time.sleep(np.random.uniform(low=0.01, high=0.05))
            scene = pyredner.Scene(camera, shapes, materials, area_lights=[], envmap=None)
            args += pyredner.RenderFunction.serialize_scene(
                scene=scene,
                num_samples=1,
                max_bounces=0,
                channels=[
                    redner.channels.alpha,
                    redner.channels.diffuse_reflectance,
                    redner.channels.uv
                ],
                device=vertices.device
            )
        result = BatchRenderFunction.apply(1, *args) + self.offset
        assert result.device == vertices.device
        return result

model = RenderModule()
model.cuda()
model = torch.nn.DataParallel(model)
batch_size = 16
optim = torch.optim.SGD(model.parameters(), lr=0.01)
while True:
    x = torch.randn(batch_size, 1).cuda()
    images = model(x)
    loss = (images - 1.0).mean()
    loss.backward()
    optim.step()
    for i in range(batch_size):
        pyredner.imwrite(images[i, :, :, 0].cpu(), 'debug_redner_%02d.exr' % i)
        if (images[i, :, :, 0] < 0.1).any():
            raise RuntimeError(i)
You may not need to store the device in the argument list anymore.
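For example, one way to drop it (just a sketch of the batching helper above, under the assumption that each per-sample render comes back on the device that was passed to serialize_scene) is to collect the per-sample outputs and stack them, so the batched result inherits its device from the renders instead of a pre-allocated buffer:

import torch
import pyredner

class BatchRenderFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, seed, *args):
        batch_dims = args[0]
        args_old_format = args[1:]
        chunk_len = int(len(args_old_format) / batch_dims)
        outputs = []
        for k in range(batch_dims):
            sub_args = args_old_format[k * chunk_len:(k + 1) * chunk_len]
            # Each per-sample render already lives on the device given to
            # serialize_scene, so no explicit device argument is needed here.
            outputs.append(pyredner.RenderFunction.forward(ctx, seed, *sub_args))
        return torch.stack(outputs, dim=0)

    @staticmethod
    def backward(ctx, grad_img):
        # None gradients for seed and batch_dims
        ret_list = (None, None,)
        for k in range(grad_img.shape[0]):
            ret_list = ret_list + pyredner.RenderFunction.backward(ctx, grad_img[k])[1:]
        return ret_list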
Thank you so much, I'll update to 0.4.14 and try again!
Thank you, this is working well in 0.4.14. I just had to remove the bit in pyredner/image.py which attempts to download freeimage (I don't have unrestricted web access within my organization).
I was not only able to run the fixed example, but my original use case is now running stably on 4 GPUs using DataParallel. I'll see if I can get DistributedDataParallel to work, too, as this should boost performance even further.
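For reference, the DistributedDataParallel variant I have in mind looks roughly like this (a minimal single-node sketch using torch.multiprocessing.spawn; it assumes the RenderModule from the examples above, so it isn't self-contained):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    # One process per GPU: each process owns exactly one device, so there is
    # no shared global device state for replicas to race on.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = RenderModule().cuda(rank)      # RenderModule as defined above
    model = DDP(model, device_ids=[rank])
    optim = torch.optim.SGD(model.parameters(), lr=0.01)

    x = torch.randn(4, 1, device='cuda:%d' % rank)   # per-process sub-batch
    images = model(x)
    loss = (images - 1.0).mean()
    loss.backward()
    optim.step()

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)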
Thank you once again.
First of all thank you for your wonderful contribution to the community.
I'm training a neural network which is tasked with predicting the transformation and deformation of a single template shape, which has approximately 50,000 vertices. I've noticed when using 2 GPUs via DataParallel that some incomplete scan-lines are present in the rendered images. Approximately 50% of the scan-lines are impacted. The alpha channel is black in these areas, as if the rays missed the geometry entirely. The other channels I'm rendering alongside the alpha channel (diffuse and uv) also show the same issue in the same pixels. Each render in my batch shows random variation in this horizontally-aligned pattern, and occasionally some samples in the batch appear to be entirely unaffected.
When I revert training back to a single GPU the problem goes away, and the rendered images look exactly as I would expect them to appear.
Unfortunately I'm not able to provide any code or images at this time, but I will see if I can encourage the same behaviour to appear using one of the provided tests/examples.
I'm using redner-0.4.5, optix-6.5.0, cuda-10.0.130, and pytorch-1.4.0 through pytorch_lightning-0.7.2. I'm compiling all dependencies myself under gcc-6.4.0, so I'm easily able to tweak the code if there are things worth testing out. However, I'm not too familiar with cuda itself as I mostly access GPUs via pytorch or other high-level applications.
Thank you for your help.