zhiheng-ldj closed this issue 3 years ago.
Hi,
I would check whether you are using the same interpolation method when you compare both implementations. You can also check out the test to see what level of accuracy you can expect. Please also check ToCVMatrix to learn how DALI and OpenCV interpret the transformation matrix passed to the corresponding functions.
Hi,
Thanks for your advice. I have checked the interpolation (both are linear) and set inverse_map=False.
After that, I got:
DALI version:
tensor([[187, 188, 178],
[187, 187, 177],
[188, 188, 178],
[196, 196, 186],
[203, 203, 193],
[200, 200, 190],
[178, 176, 167],
[161, 158, 149],
[169, 164, 156],
[154, 149, 141],
[179, 174, 166],
[191, 186, 178],
[202, 193, 182],
[201, 194, 182],
[173, 167, 155],
[166, 162, 149],
[155, 152, 139],
[172, 170, 157],
[182, 182, 172],
[178, 178, 170],
[167, 168, 161],
[141, 143, 143],
[151, 153, 154],
[147, 149, 152],
[136, 132, 134],
[118, 114, 116],
[ 95, 91, 93],
[ 76, 73, 76],
[ 66, 64, 67],
[ 61, 60, 64],
[ 63, 64, 70],
[ 67, 68, 74],
[ 72, 73, 79],
[ 84, 87, 94],
[112, 115, 122],
[138, 141, 148],
[174, 185, 199],
[196, 206, 220],
[210, 220, 234],
[212, 221, 232],
[214, 222, 232],
[217, 223, 233],
[193, 196, 203],
[233, 235, 239],
[243, 245, 248],
[255, 255, 255],
[253, 253, 253],
[253, 253, 253],
[253, 252, 251],
[253, 252, 252],
[252, 250, 251],
[249, 246, 251],
[246, 243, 249],
[244, 241, 248],
[243, 240, 247],
[246, 243, 249],
[249, 246, 251],
[255, 254, 255],
[250, 249, 249],
[253, 252, 251],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[255, 255, 255],
[250, 250, 250],
[252, 252, 252],
[247, 247, 247],
[238, 238, 238],
[244, 244, 244],
[237, 237, 237],
[243, 243, 243],
[238, 238, 238],
[232, 232, 232],
[234, 234, 234],
[228, 228, 228],
[228, 228, 228],
[224, 224, 224],
[221, 221, 221],
[218, 218, 218],
[217, 217, 217],
[214, 214, 214],
[210, 210, 210],
[211, 211, 211],
[208, 208, 208],
[201, 201, 201],
[204, 204, 204],
[197, 197, 197],
[200, 200, 200]], device='cuda:0', dtype=torch.uint8)
cv2 version:
[[187 189 178]
[187 188 177]
[187 187 177]
[194 194 184]
[203 203 193]
[201 201 191]
[185 182 173]
[158 155 146]
[177 173 164]
[161 157 148]
[180 176 167]
[190 186 176]
[200 191 182]
[204 197 186]
[178 172 160]
[168 164 152]
[156 152 140]
[168 166 153]
[182 182 172]
[177 178 169]
[169 170 165]
[143 145 142]
[160 161 162]
[159 159 162]
[149 147 150]
[132 128 129]
[108 104 105]
[ 83 81 84]
[ 69 68 71]
[ 62 61 66]
[ 65 66 71]
[ 64 65 70]
[ 62 63 68]
[ 66 69 76]
[ 91 94 101]
[118 122 130]
[167 177 189]
[190 201 216]
[209 219 232]
[213 221 234]
[214 222 232]
[217 224 232]
[188 191 200]
[226 228 231]
[245 246 247]
[255 255 255]
[253 253 253]
[253 253 253]
[255 255 253]
[255 255 254]
[254 252 253]
[250 248 253]
[247 245 250]
[244 241 248]
[241 238 245]
[243 241 247]
[248 246 250]
[255 254 255]
[252 251 251]
[252 252 250]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[255 255 255]
[250 250 250]
[252 252 252]
[248 248 248]
[237 237 237]
[244 244 244]
[239 239 239]
[243 243 243]
[239 239 239]
[232 232 232]
[235 235 235]
[229 229 229]
[228 228 228]
[224 224 224]
[221 221 221]
[219 219 219]
[217 217 217]
[214 214 214]
[211 211 211]
[213 213 213]
[205 205 205]
[204 204 204]
[206 206 206]
[196 196 196]
[199 199 199]]
They are indeed more similar. Is the remaining difference expected?
Thanks!
@zhiheng-ldj There are some differences between OpenCV and DALI, and some of them are there on purpose. The biggest difference is the origin of the image coordinate system. In OpenCV, (0, 0) is the center of the top-left pixel. In DALI, (0, 0) is the top-left corner of the top-left pixel. This is the system used by graphics libraries such as DirectX, OpenGL or Vulkan, as well as by texture mapping units. We actually use OpenCV as a reference in our tests, and to make it work with the origin at the pixel corner, we apply:
import numpy as np

def ToCVMatrix(matrix):
    # DALI (origin at pixel corner) -> OpenCV (origin at pixel center)
    offset = np.matmul(matrix, np.array([[0.5], [0.5], [1]]))
    result = matrix.copy()
    result[0][2] = offset[0, 0] - 0.5
    result[1][2] = offset[1, 0] - 0.5
    return result
before feeding the matrix to OpenCV. Also, OpenCV's linear interpolation is approximate: it uses precalculated weights and has only 32 linear steps. OpenCV recently introduced something like INTER_LINEAR_EXACT. DALI always uses exact floating-point interpolation. The remaining differences are due to numerical discrepancies resulting from a different order of operations, the use of fused multiply-add, or slightly different rounding modes; these differ even within DALI, between CPU and GPU.
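To make the two conventions concrete, here is a small NumPy check (not taken from the DALI test suite; the matrix and point values are arbitrary assumptions) that a point in DALI coordinates maps consistently under the original matrix and under its ToCVMatrix conversion, once the point is shifted by half a pixel into OpenCV coordinates.

```python
import numpy as np

def ToCVMatrix(matrix):
    # DALI (origin at pixel corner) -> OpenCV (origin at pixel center)
    offset = np.matmul(matrix, np.array([[0.5], [0.5], [1]]))
    result = matrix.copy()
    result[0][2] = offset[0, 0] - 0.5
    result[1][2] = offset[1, 0] - 0.5
    return result

def apply(m, xy):
    # Apply a 2x3 affine matrix to a 2D point (x, y)
    return m[:, :2] @ np.asarray(xy, dtype=np.float64) + m[:, 2]

# An arbitrary example transform in DALI (corner-based) coordinates
m_dali = np.array([[0.75, 0.1, 3.0],
                   [-0.2, 1.25, -4.0]])
m_cv = ToCVMatrix(m_dali)

p_dali = np.array([10.0, 20.0])  # a point in DALI coordinates
p_cv = p_dali - 0.5              # the same point in OpenCV coordinates

# Both matrices describe the same geometric mapping; only the origin differs
print(apply(m_cv, p_cv), apply(m_dali, p_dali) - 0.5)
```

Only the translation column changes; the linear 2x2 part is identical in both conventions.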
Thanks for your very detailed explanation!
By the way, should matrix be 2x3? As I understand it, we need to convert the matrix passed to ops.WarpAffine with this ToCVMatrix, and that matrix is 2x3. Am I wrong?
Yes, a two-dimensional affine transform matrix is 2x3; that's what both WarpAffine and OpenCV accept. This adjustment function converts DALI -> OpenCV; to go the other way, you need to adjust the matrix in the opposite way. If you're about to do the latter, please double-check (or triple-check, or check even more times) that the matrix you use with OpenCV is indeed correct. I've seen very many times that the matrix used with OpenCV was ever so slightly wrong. This does not manifest itself when the matrix is relatively close to identity, but it does with large transforms: for example, rotating an image by a multiple of 90 degrees causes padding to be applied on one side, while the opposite row or column is lost.
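A minimal NumPy sketch of that pitfall, assuming a hypothetical source width W = 640: a 90-degree rotation written in DALI's corner-based coordinates uses a translation of W, while the equivalent OpenCV matrix (obtained with ToCVMatrix) needs W - 1. Reusing W with OpenCV sends source column 0 to row W, one past the last valid row, which produces exactly the padded row on one side and the lost row on the other.

```python
import numpy as np

def ToCVMatrix(matrix):
    # DALI (origin at pixel corner) -> OpenCV (origin at pixel center)
    offset = np.matmul(matrix, np.array([[0.5], [0.5], [1]]))
    result = matrix.copy()
    result[0][2] = offset[0, 0] - 0.5
    result[1][2] = offset[1, 0] - 0.5
    return result

W = 640  # hypothetical source image width

# 90-degree rotation in DALI (corner-based) coordinates: x' = y, y' = W - x.
# It maps the W x H source exactly onto an H x W canvas.
m_dali = np.array([[0.0, 1.0, 0.0],
                   [-1.0, 0.0, float(W)]])

m_cv = ToCVMatrix(m_dali)  # the y translation becomes W - 1 = 639, not W

# Where does the center of source pixel (x=0, y=0) land in OpenCV coordinates?
src_pixel = np.array([0.0, 0.0, 1.0])
print(m_cv @ src_pixel)    # row W - 1: the last valid output row
print(m_dali @ src_pixel)  # row W: outside the output if m_dali is reused with OpenCV
```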
Anyway, where does your transform matrix come from? You can build one using the operators in the transforms module. The example below creates a matrix with a random rotation (-20 to 20 degrees) and shear (-10 to 10 degrees in each of X and Y). Assuming that your center coordinates are calculated correctly in corner-based coordinates, the output matrices will produce a mathematically correct result when used with DALI.
center = [224, 224]
m = fn.transforms.rotation(angle=fn.random.uniform(range=(-20, 20)), center=center)
m = fn.transforms.shear(m, angles=fn.random.uniform(range=(-10, 10), shape=[2]), center=center)
If you happen to have matrices that produce correct results with OpenCV, you can use these operators to adjust the transform (from OpenCV to DALI), but do that only if you already have the matrices in OpenCV coordinates.
m = ...  # a DataNode with a transform in OpenCV coords
m = fn.transforms.translation(m, offset=[-0.5, -0.5], reverse_order=True)  # apply translation before m
m = fn.transforms.translation(m, offset=[0.5, 0.5])  # apply translation on top of m
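For reference, here is what those two fn.transforms.translation calls compute, written out in plain NumPy with 3x3 homogeneous matrices. The helper names (translation, to_3x3, cv_to_dali) are hypothetical; the sketch assumes reverse_order=True means the new translation is applied before m, as the comment says, and checks that the result undoes ToCVMatrix.

```python
import numpy as np

def ToCVMatrix(matrix):
    # DALI (origin at pixel corner) -> OpenCV (origin at pixel center)
    offset = np.matmul(matrix, np.array([[0.5], [0.5], [1]]))
    result = matrix.copy()
    result[0][2] = offset[0, 0] - 0.5
    result[1][2] = offset[1, 0] - 0.5
    return result

def translation(offset):
    # 3x3 homogeneous translation matrix
    t = np.eye(3)
    t[:2, 2] = offset
    return t

def to_3x3(m):
    # Append the implicit [0, 0, 1] row to a 2x3 affine matrix
    return np.vstack([m, [0.0, 0.0, 1.0]])

def cv_to_dali(m_cv):
    # Rightmost factor acts first: shift by -0.5 (reverse_order=True),
    # then apply m_cv, then shift by +0.5 (applied on top)
    return (translation([0.5, 0.5]) @ to_3x3(m_cv) @ translation([-0.5, -0.5]))[:2]

m_cv = np.array([[0.75, 0.0, 0.0],
                 [0.0, 0.75, 0.0]])
m_dali = cv_to_dali(m_cv)
print(m_dali)  # translation becomes 0.5 - 0.75 * 0.5 = 0.125 in both axes
```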
Hi @mzient, many thanks for your help!
I have read your advice carefully. The function transforms a matrix from DALI to OpenCV, which means:
Mcv = ToCVMatrix(Mdali)
That makes sense!
Besides, how should I determine the center? Is it always [224, 224]? How do you define the coordinates?
Thanks!
m = ...  # a DataNode with a transform in OpenCV coords
m = fn.transforms.translation(m, offset=[-0.5, -0.5], reverse_order=True)  # apply translation before m
m = fn.transforms.translation(m, offset=[0.5, 0.5])  # apply translation on top of m
Besides, where should I put this code segment? I tried several times, but I couldn't get it right.
My demo code:
class NumpyReaderPipeline(Pipeline):
    def __init__(self, batch_size, device, image_path, files, img_files, num_threads, device_id, seed, path, size, mean,
                 std, shard_id, num_shards, shuffle, read_ahead):
        super(NumpyReaderPipeline, self).__init__(batch_size, num_threads, device_id, seed=seed)
        self.input = ops.FileReader(file_root=image_path, files=img_files, shard_id=shard_id, num_shards=num_shards,
                                    seed=seed, random_shuffle=shuffle, read_ahead=read_ahead)
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
        self.np_reader = ops.NumpyReader(device=device,
                                         file_root=path[0],
                                         files=files,
                                         shard_id=shard_id,
                                         random_shuffle=shuffle,
                                         seed=seed,
                                         read_ahead=read_ahead,
                                         num_shards=num_shards)
        self.np_reader2 = ops.NumpyReader(device=device,
                                          file_root=path[1],
                                          files=files,
                                          shard_id=shard_id,
                                          random_shuffle=shuffle,
                                          seed=seed,
                                          read_ahead=read_ahead,
                                          num_shards=num_shards)
        self.np_reader3 = ops.NumpyReader(device=device,
                                          file_root=path[2],
                                          files=files,
                                          shard_id=shard_id,
                                          random_shuffle=shuffle,
                                          seed=seed,
                                          read_ahead=read_ahead,
                                          num_shards=num_shards)
        self.np_reader4 = ops.NumpyReader(device=device,
                                          file_root=path[3],
                                          files=files,
                                          random_shuffle=shuffle,
                                          seed=seed,
                                          shard_id=shard_id,
                                          read_ahead=read_ahead,
                                          num_shards=num_shards)
        self.np_reader5 = ops.NumpyReader(device=device,
                                          file_root=path[4],
                                          files=files,
                                          random_shuffle=shuffle,
                                          seed=seed,
                                          shard_id=shard_id,
                                          read_ahead=read_ahead,
                                          num_shards=num_shards)
        self.np_reader6 = ops.NumpyReader(device=device,
                                          file_root=path[5],
                                          files=files,
                                          random_shuffle=shuffle,
                                          seed=seed,
                                          read_ahead=read_ahead,
                                          shard_id=shard_id,
                                          num_shards=num_shards)
        self.transform_source = ops.ExternalSource()
        self.mat = np.zeros([1, 2, 3])
        self.mat[0,:,:] = np.mat([[0.75,0,0],[0,0.75,0]])
        self.mat = self.mat.astype(np.float32)
        self.wa = ops.WarpAffine(device="gpu",
                                 size=(96,96),
                                 inverse_map=False,
                                 interp_type = types.INTERP_LINEAR)
        self.cmn = ops.CropMirrorNormalize(device="gpu",
                                           output_dtype=types.FLOAT,
                                           output_layout=types.NCHW,
                                           crop=(size, size),
                                           image_type=types.RGB,
                                           mean=mean,
                                           std=std)

    def define_graph(self):
        self.transform = self.transform_source()
        jpegs, _ = self.input(name="Reader")
        images = self.decode(jpegs)
        # output = self.cmn(images)
        output = self.wa(images, self.transform.gpu())
        # return output, self.np_reader(), self.np_reader2(), self.np_reader3(), self.np_reader4(), self.np_reader5(), self.np_reader6()
        return output

    def iter_setup(self):
        # Generate the transforms for the batch and feed them to the ExternalSource
        self.feed_input(self.transform, self.mat)
@zhiheng-ldj I encourage you to use the new "functional API" - you don't have to separately create operator objects and invoke them. It's much simpler that way. Your code above would look like:
import nvidia.dali as dali
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
import numpy as np
import nvidia.dali.fn as fn

def NumpyReaderPipeline(batch_size, device, image_path, files, img_files, num_threads, device_id, seed, path,
                        size, mean, std, shard_id, num_shards, shuffle, read_ahead):
    pipe = Pipeline(batch_size, num_threads, device_id, seed=seed)
    image_files, labels = fn.file_reader(file_root=image_path, files=img_files, shard_id=shard_id, num_shards=num_shards,
                                         seed=seed, random_shuffle=shuffle, read_ahead=read_ahead)
    images = fn.image_decoder(image_files, device="mixed")
    np_inputs = [fn.numpy_reader(device=device,
                                 file_root=x,
                                 files=files,
                                 shard_id=shard_id,
                                 random_shuffle=shuffle,
                                 seed=seed,
                                 read_ahead=read_ahead,
                                 num_shards=num_shards) for x in path]
    mat = np.float32([[0.75,0,0],[0,0.75,0]])
    warped = fn.warp_affine(images, matrix=mat,
                            size=(96,96),
                            inverse_map=False,
                            interp_type=types.INTERP_LINEAR)
    #output = fn.crop_mirror_normalize(
    #    warped,
    #    dtype=types.FLOAT,
    #    output_layout="CHW",
    #    crop=(size, size),
    #    mean=mean,
    #    std=std)
    #pipe.set_outputs(output, *np_inputs)
    pipe.set_outputs(warped)
    return pipe

files = ["1", "2", "3"]
img_files = ["1", "2", "3"]
pipe = NumpyReaderPipeline(4, "cpu", [""], files, img_files, 3, 0, 1234, [""]*6, 448, [0.,0.,0.], [1.,1.,1.], 0, 1, True, 1)
pipe.build()
Please note that you can feed a NumPy array (but not np.mat!) directly to WarpAffine; there's no need for ExternalSource (which is also greatly simplified now).
DALI can now automatically promote NumPy arrays (as well as PyTorch tensors and MXNet ndarrays) to constant nodes; you can pass them directly as inputs or argument inputs.
Here's an example where the aforementioned transforms are applied to your mat:
import nvidia.dali as dali
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
import numpy as np
import nvidia.dali.fn as fn

def NumpyReaderPipeline(batch_size, device, image_path, files, img_files, num_threads, device_id, seed, path,
                        size, mean, std, shard_id, num_shards, shuffle, read_ahead):
    pipe = Pipeline(batch_size, num_threads, device_id, seed=seed)
    image_files, labels = fn.file_reader(file_root=image_path, files=img_files, shard_id=shard_id, num_shards=num_shards,
                                         seed=seed, random_shuffle=shuffle, read_ahead=read_ahead)
    images = fn.image_decoder(image_files, device="mixed")
    np_inputs = [fn.numpy_reader(device=device,
                                 file_root=x,
                                 files=files,
                                 shard_id=shard_id,
                                 random_shuffle=shuffle,
                                 seed=seed,
                                 read_ahead=read_ahead,
                                 num_shards=num_shards) for x in path]
    mat = np.float32([[0.75,0,0],[0,0.75,0]])
    m = fn.transforms.translation(mat, offset=[-0.5, -0.5], reverse_order=True)  # apply translation before m
    m = fn.transforms.translation(m, offset=[0.5, 0.5])  # apply translation on top of m
    warped = fn.warp_affine(images, matrix=m,
                            size=(96,96),
                            inverse_map=False,
                            interp_type=types.INTERP_LINEAR)
    #output = fn.crop_mirror_normalize(
    #    warped,
    #    dtype=types.FLOAT,
    #    output_layout="CHW",
    #    crop=(size, size),
    #    mean=mean,
    #    std=std)
    #pipe.set_outputs(output, *np_inputs)
    pipe.set_outputs(warped)
    return pipe

files = ["1", "2", "3"]
img_files = ["1", "2", "3"]
pipe = NumpyReaderPipeline(4, "cpu", [""], files, img_files, 3, 0, 1234, [""]*6, 448, [0.,0.,0.], [1.,1.,1.], 0, 1, True, 1)
pipe.build()
And one more example, where a lambda passed to ExternalSource generates one sample at a time (batch=False), returning the same NumPy matrix as before:
import nvidia.dali as dali
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
import numpy as np
import nvidia.dali.fn as fn

def NumpyReaderPipeline(batch_size, device, image_path, files, img_files, num_threads, device_id, seed, path,
                        size, mean, std, shard_id, num_shards, shuffle, read_ahead):
    pipe = Pipeline(batch_size, num_threads, device_id, seed=seed)
    image_files, labels = fn.file_reader(file_root=image_path, files=img_files, shard_id=shard_id, num_shards=num_shards,
                                         seed=seed, random_shuffle=shuffle, read_ahead=read_ahead)
    images = fn.image_decoder(image_files, device="mixed")
    np_inputs = [fn.numpy_reader(device=device,
                                 file_root=x,
                                 files=files,
                                 shard_id=shard_id,
                                 random_shuffle=shuffle,
                                 seed=seed,
                                 read_ahead=read_ahead,
                                 num_shards=num_shards) for x in path]
    mat = fn.external_source(lambda: np.float32([[0.75,0,0],[0,0.75,0]]), batch=False)
    m = fn.transforms.translation(mat, offset=[-0.5, -0.5], reverse_order=True)  # apply translation before m
    m = fn.transforms.translation(m, offset=[0.5, 0.5])  # apply translation on top of m
    warped = fn.warp_affine(images, matrix=m,
                            size=(96,96),
                            inverse_map=False,
                            interp_type=types.INTERP_LINEAR)
    #output = fn.crop_mirror_normalize(
    #    warped,
    #    dtype=types.FLOAT,
    #    output_layout="CHW",
    #    crop=(size, size),
    #    mean=mean,
    #    std=std)
    #pipe.set_outputs(output, *np_inputs)
    pipe.set_outputs(warped)
    return pipe

files = ["1", "2", "3"]
img_files = ["1", "2", "3"]
pipe = NumpyReaderPipeline(4, "cpu", [""], files, img_files, 3, 0, 1234, [""]*6, 448, [0.,0.,0.], [1.,1.,1.], 0, 1, True, 1)
pipe.build()
@zhiheng-ldj Your code uses a lot of deprecated features:
- output_dtype: now called dtype in all operators
- image_type: no longer used nor necessary (except perhaps in the image decoder, but RGB is the default there)
- layout specification: now it's a string such as "CHW" or "HWC", without the leading N (the sample index is not part of the layout of tensors in a tensor list; this dimension is special, for example it can't be transposed with the others)
Tons of thanks for your help; I have modified my code as you posted.
By the way, I noticed that DALI supports direct transformations instead of an affine matrix, such as nvidia.dali.ops.transforms.Scale and nvidia.dali.ops.transforms.Rotation. Could you provide a simple demo showing how to use them?
From my standpoint, it would be more convenient for implementing image transformations.
Thanks for your help!
Btw, I noticed that DALI supports direct transformation instead of affine matrix, such as nvidia.dali.ops.transforms.Scale and nvidia.dali.ops.transforms.Rotation.
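These operators manipulate affine transform matrices. All operators in this family work like this:
- create an (affine transform) matrix according to the named arguments (such as rotation angle and center)
- optionally combine it with a pre-existing transform matrix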
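The "optionally combine" step is just a matrix product. Below is a NumPy model of it; the function compose is hypothetical, and its reverse_order flag assumes the semantics suggested by the comments in the OpenCV-to-DALI translation snippet (default: the new transform is applied on top of the existing one; reverse_order=True: it is applied before).

```python
import numpy as np

def to_3x3(m):
    # Append the implicit [0, 0, 1] row to a 2x3 affine matrix
    return np.vstack([m, [0.0, 0.0, 1.0]])

def compose(new, existing, reverse_order=False):
    # In a product, the rightmost factor acts first on a point
    a, b = to_3x3(new), to_3x3(existing)
    combined = b @ a if reverse_order else a @ b
    return combined[:2]

scale = np.array([[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]])    # existing transform
shift = np.array([[1.0, 0.0, 10.0], [0.0, 1.0, 10.0]])  # new transform

on_top = compose(shift, scale)                      # p -> 0.5 * p + 10
before = compose(shift, scale, reverse_order=True)  # p -> 0.5 * (p + 10) = 0.5 * p + 5
```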
An (almost) complete example:
import nvidia.dali as dali
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import Pipeline

pipe = Pipeline(...)  # configure your pipeline
jpegs, labels = fn.file_reader(...)  # configure your reader here
images = fn.image_decoder(jpegs, device="mixed")
# look up the shape of the input images - you can't use `fn.shapes` on images when using NVJPEG
shapes = fn.peek_image_shape(jpegs)  # HWC
w = fn.slice(shapes, 1, 1, axes=[0])  # extract W
h = fn.slice(shapes, 0, 1, axes=[0])  # extract H
size = fn.cast(fn.cat(w, h), dtype=dali.types.FLOAT)  # combine to WH and convert to float
out_width, out_height = 300, 300  # whatever
center = size / 2
angle = fn.random.uniform(range=(-45, 45))
m = fn.transforms.rotation(angle=angle, center=center)  # no pre-existing transform
m = fn.transforms.scale(m, scale=[0.8, 1.2], center=center)  # anisotropic scaling
m = fn.transforms.translation(m, offset=fn.random.uniform(range=(-20, 20), shape=[2]))  # random offset +/- 20 pixels in each axis
# this effectively resizes the image from the input size to the output size
m = fn.transforms.crop(m, from_start=[0, 0], from_end=size, to_start=[0, 0], to_end=[out_width, out_height])
transformed = fn.warp_affine(images, matrix=m, size=[out_height, out_width], inverse_map=False)
pipe.set_outputs(transformed, labels)
I haven't tried to run the code above, so it may contain some errors/typos/..., but you get the general idea.
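The last transform in that chain, fn.transforms.crop, maps one box onto another. As a sketch, assuming crop maps from_start/from_end onto to_start/to_end with per-axis scaling plus translation (the helper box_to_box and the 640x480 input size are hypothetical), the matrix can be written in NumPy like this:

```python
import numpy as np

def box_to_box(from_start, from_end, to_start, to_end):
    # 2x3 affine matrix mapping the box [from_start, from_end] onto [to_start, to_end]
    fs, fe, ts, te = (np.asarray(v, dtype=np.float64) for v in (from_start, from_end, to_start, to_end))
    scale = (te - ts) / (fe - fs)  # per-axis scale factors (x, y)
    m = np.zeros((2, 3))
    m[0, 0], m[1, 1] = scale
    m[:, 2] = ts - scale * fs      # so that from_start lands on to_start
    return m

size = [640.0, 480.0]      # hypothetical input (W, H)
out_size = [300.0, 300.0]  # (out_width, out_height)
m = box_to_box([0, 0], size, [0, 0], out_size)

# With both boxes starting at the origin this is a pure (anisotropic) resize
corner = m @ np.array([640.0, 480.0, 1.0])
```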
There's an example here; I think it uses all the transforms that we have: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/math/geometric_transforms.html#Affine-Transform
You can actually run this example on your machine; it's in our GitHub repo. You'll also need the DALI_extra repository to get the data (NOTE: you need Git LFS to clone DALI_extra!).
I have tried your idea and code as follows:
def __init__:
    .......
    center = size/2
    angle = fn.uniform(range=(-45, 45))
    m = fn.transforms.rotation(angle=angle, center=center)  # no pre-existing transform
    m = fn.transforms.scale(m, scale=[0.8, 1.2], center=center)  # anisotropic scaling
    self.wa = ops.WarpAffine(device="gpu",
                             matrix=m,
                             size=(size, size),
                             inverse_map=False,
                             interp_type=types.INTERP_LINEAR)
    ......
def define_graph:
    ......
    output = self.wa(images)
    return output
And it raises an error:
Assert on "arg_in_view.shape.sample_dim() == 1" failed: ``angle`` must be a 1D tensor
How should I fix it?
Besides, is the matrix produced only once in __init__? How should I produce a matrix for each image? Should I use a for loop in __init__?
Thanks!!
@zhiheng-ldj
That's a bug we fixed some time ago. Which version are you using? Updating to the latest DALI should help.
If updating is not an option, specify angle=fn.uniform(range=(-45, 45), shape=[1]).
Thanks for your help, it works. (center should be (size/2, size/2))
I will update the version soon (the network is not good for downloading right now).
One more question: how should I produce a matrix for each image? Can I do this in a simple way (instead of constructing a for loop over the batch and feeding the results to fn.warp_affine)?
Thanks!!
@zhiheng-ldj
Besides, is the matrix produced only once in init? How should I produce a matrix for each image? Using for loop in init?
You don't have to. All operators work with batches, so these operators generate an entire batch of matrices, one per sample.
in init
Let me reiterate: you don't have to split the logic between __init__ and define_graph, and you don't have to inherit from Pipeline. You can just create a pipeline object, define your processing graph wherever you see fit, and pass the output nodes to pipe.set_outputs, as I've done in my code snippets. This is the new and preferred way of defining pipelines (in fact, we plan to simplify it even further).
That said, your questions sound a bit alarming, so I'll proactively explain how DALI works:
The operations you define by calling operators (with fn.some_op(input, arg1=value1, ...) or, in the legacy API, with instances of ops.SomeOp) don't perform any data processing; they just define the connections between the various operations in the pipeline.
The result of calling a DALI operator is a DataNode (or a tuple of DataNodes, when an operator has multiple outputs). A DataNode is a node in the processing graph that represents a batch of data produced by an operator.
angle = fn.random.uniform(range=(-45, 45))  # angle is an entire batch of random numbers
m = fn.transforms.rotation(angle=angle)  # m is an entire batch of matrices
When you set the pipeline outputs, the graph is traversed, and only the nodes that contribute to the outputs are actually instantiated.
The next step is calling pipe.build(); this is where the pipeline operators are actually instantiated in the backend.
Finally, calling out = pipe.run() executes the graph and produces a tuple of output batches.
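As an analogy only (plain NumPy, not DALI), here is what "an entire batch" means for the angle and m nodes in the snippet above: one random angle per sample and one 2x3 matrix per sample. The batch size, the seed, and the textbook rotation sign convention are assumptions; DALI's own convention for fn.transforms.rotation is not claimed here.

```python
import numpy as np

rng = np.random.default_rng(1234)
batch_size = 4

# One random angle per sample, like fn.random.uniform(range=(-45, 45))
angles = rng.uniform(-45.0, 45.0, size=batch_size)
rad = np.deg2rad(angles)
cos, sin = np.cos(rad), np.sin(rad)

# One 2x3 rotation matrix per sample: shape (batch_size, 2, 3)
m = np.zeros((batch_size, 2, 3))
m[:, 0, 0], m[:, 0, 1] = cos, -sin
m[:, 1, 0], m[:, 1, 1] = sin, cos
```

No Python loop over samples is needed; the whole batch is produced at once, which is also why no for loop is needed in the pipeline definition.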
center should be (size/2, size/2)
Definitely not when size is defined the way I defined it (a two-element vector holding W and H); that would only be the case if size were a scalar constant.
One more question: can operators such as fn.warp_affine be applied to .npy arrays read by ops.NumpyReader?
Thanks! I think it should work?
One more question, can these operators such as fn.warp_affine be applied to npy arrays? which is read by ops.NumpyReader?
When you read the data using NumpyReader, it is treated like any other batch of tensors from any other operator, so it should work fine.
One more question, can these operators such as fn.warp_affine be applied to npy arrays? which is read by ops.NumpyReader?
The input format doesn't matter. However, warp_affine expects a specific data layout, with a channel dimension. So if your images are 2D numpy arrays, you'll have to do this:
npy = fn.numpy_reader(...) # 2D arrays
img = fn.reshape(npy, rel_shape=[1,1,-1], layout="HWC")
Explanation:
- rel_shape specifies the output shape relative to the input dimensions: 1 means input extent * 1. The value -1 is special: it means that all remaining dimensions are merged and put there. We've already specified all of the input dimensions, so nothing remains, and a new dimension of size 1 is appended.
- layout="HWC" sets the layout information for the output tensor.
- reshape is the cheapest of all operators: it doesn't touch or even copy the data. It produces a tensor with a new shape and layout, but the data is just another reference to the input.
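The NumPy equivalent of this reshape, for a hypothetical 48x64 single-channel sample, appends a size-1 channel axis and (for a contiguous array) returns a view rather than a copy, which mirrors the "no copy" behavior of reshape:

```python
import numpy as np

npy = np.arange(48 * 64, dtype=np.float32).reshape(48, 64)  # hypothetical 2D sample (H, W)
hwc = npy.reshape(npy.shape + (1,))                          # (H, W, 1): one channel appended
```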
Hi, thanks for your kind help! One more theoretical question:
I noticed an affine matrix function written for use with cv2:
import numpy as np

def _get_affine_matrix(self, center, scale, res, rot=0):
    # Generate transformation matrix
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1]/2
        t_mat[1, 2] = -res[0]/2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t
Is it a general form, and is it consistent with DALI?
Thanks!
As I understand it, the identity matrix should be [1,0,0; 0,1,0].
If I scale, it should look like [0.5,0,0; 0,0.5,0]. If I translate, it should look like [1,0,10; 0,1,10].
Is this how DALI works, or is my understanding wrong?
@zhiheng-ldj
Yes; your matrices would be the results of fn.transforms.scale(scale=[0.5,0.5]) and fn.transforms.translation(offset=[10,10]), respectively.
If you want to know what matrices you're getting, you can temporarily add the matrix as your pipeline's output and print it.
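To check such matrices by hand, apply the 2x3 matrix to a point as x' = A @ [x, y] + t, where A is the left 2x2 block and t is the last column. A NumPy sketch with an arbitrary point (the helper apply is hypothetical):

```python
import numpy as np

def apply(m, xy):
    # x' = A @ [x, y] + t, where A = m[:, :2] and t = m[:, 2]
    return m[:, :2] @ np.asarray(xy, dtype=np.float64) + m[:, 2]

identity = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
scale = np.array([[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]])
translate = np.array([[1.0, 0.0, 10.0], [0.0, 1.0, 10.0]])

p = [100.0, 40.0]
print(apply(identity, p), apply(scale, p), apply(translate, p))
```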
In this case, does the center coordinate matter? Or is the center just a translation?
@zhiheng-ldj Translation doesn't require any kind of centering. For scaling, yes, the center does matter. If you want to "zoom in" (the canvas size stays the same and you just magnify the central part), then you should use the real image center. If, on the other hand, you want to scale the image to fill a different canvas size, then the center should be at (0, 0).
And yes, you can implement centering as two additional translations, but that's less efficient.
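Here is a NumPy sketch of that remark (not DALI's implementation; the 224x224 image and the 2x zoom are hypothetical): scaling about a center is translation(center) @ scale @ translation(-center), which folds into a single matrix whose translation column is (1 - s) * center, so the center stays fixed.

```python
import numpy as np

def translation(offset):
    # 3x3 homogeneous translation matrix
    t = np.eye(3)
    t[:2, 2] = offset
    return t

def scaling(s):
    # 3x3 homogeneous (possibly anisotropic) scaling matrix
    return np.diag([s[0], s[1], 1.0])

center = np.array([112.0, 112.0])  # hypothetical center of a 224 x 224 image (W/2, H/2)
s = np.array([2.0, 2.0])           # zoom in 2x

# Rightmost factor acts first: move the center to the origin, scale, move back
zoom = translation(center) @ scaling(s) @ translation(-center)
# Scaling about (0, 0) instead, e.g. to fill a 448 x 448 canvas
resize = scaling(s)
```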
Hi, sorry for the delay; I was away for the Chinese New Year break.
I'm done with the normal image transformations. However, I have some questions about how to deal with heatmaps.
I am trying to apply an affine matrix to my heatmaps. My heatmaps have shape (23, 48, 48): 23 is the number of joints and (48, 48) is the heatmap size.
How should I reshape my heatmaps to apply the affine matrix? I tried:
npy_0 = fn.reshape(npy_0, rel_shape=[1,1,1,-1], layout="NHWC")
It doesn't work:
Assert on "static_cast<int>(out_size_f.size()) == spatial_ndim" failed: output_size must specify same number of dimensions as the input (excluding channels)
Thanks for your help!
@zhiheng-ldj
You can move the joint dimension to the end, so that the joints are treated as channels:
joint_last = fn.transpose(joint_first, perm=[1,2,0], layout="HWC")
After you've transformed your data, you can transpose it back:
joint_first = fn.transpose(joint_last, perm=[2,0,1], layout="CHW")
Alternatively, you can reshape the data so that the joints become a depth dimension:
npy_0 = fn.reshape(npy_0, rel_shape=[1,1,1,-1], layout="DHWC")
Now your joint is a depth dimension. For that to work, you'd need to construct your matrices as 3D transforms (just keep the identity transform in the Z axis).
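The same axis shuffle can be checked in NumPy on random data with the (23, 48, 48) heatmap shape from the question, assuming perm follows NumPy's convention (output axis i is input axis perm[i]):

```python
import numpy as np

heatmaps = np.random.rand(23, 48, 48).astype(np.float32)  # (joints, H, W)

joint_last = np.transpose(heatmaps, (1, 2, 0))     # (H, W, joints), like perm=[1, 2, 0]
joint_first = np.transpose(joint_last, (2, 0, 1))  # back to (joints, H, W), like perm=[2, 0, 1]
```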
joint_first = fn.transpose(joint_last, perm=[2,0,1], layout="CHW")
failed: Argument "layout" is not supported by operator "Transpose".
Is that a version bug?
Here is my code:
npy_0 = self.np_reader() # 2D arrays
npy_0 = fn.transpose(npy_0, perm=[1,2,0], layout="HWC")
# npy_0 = fn.reshape(npy_0, rel_shape=[1,1,1,-1], layout="DHWC")
npy_0 = self.wa_mask1(npy_0)
npy_0 = fn.transpose(npy_0, perm=[2,0,1], layout="CHW")
Please check the transpose operator documentation; the argument is named output_layout, to be precise.
Hi, could you point out why this error is raised?
I have checked the shape of npy_0 and it is (23,24,24).
How should I fix it?
Thanks!
RuntimeError: Critical error in pipeline:
Error when executing CPU operator Transpose, instance name: "__Transpose_12", encountered:
[../dali/operators/generic/transpose/transpose.h:67] Assert on "output_layout_.ndim() == sample_ndim" failed
Stacktrace (10 entries):
[frame 0]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali_operators-28d53e7c.so(+0x402961) [0x7f8c1a227961]
[frame 1]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali_operators-28d53e7c.so(+0xe699c8) [0x7f8c1ac8e9c8]
[frame 2]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali_operators-28d53e7c.so(+0xe69ba2) [0x7f8c1ac8eba2]
[frame 3]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali-d644bc55.so(+0xa0522) [0x7f8c3d522522]
[frame 4]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali-d644bc55.so(+0xa1e0a) [0x7f8c3d523e0a]
[frame 5]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali-d644bc55.so(+0x854c0) [0x7f8c3d5074c0]
[frame 6]: /home/admin/miniconda3/lib/python3.7/site-packages/nvidia/dali/../../nvidia_dali_cuda101.libs/libdali-d644bc55.so(+0xf81ec) [0x7f8c3d57a1ec]
[frame 7]: /home/admin/miniconda3/lib/python3.7/site-packages/torch/lib/../../../.././libstdc++.so.6(+0xc819d) [0x7f8d092d419d]
[frame 8]: /lib/x86_64-linux-gnu/libpthread.so.0(+0x76db) [0x7f8d223f16db]
[frame 9]: /lib/x86_64-linux-gnu/libc.so.6(clone+0x3f) [0x7f8d2211a88f]
Current pipeline object is no longer valid.
@zhiheng-ldj I investigated the problem and it turns out there's a bug in Transpose: specifying output_layout doesn't work when the input has no layout info. A workaround is to set some layout (with the proper number of dimensions) first.
import numpy as np
import nvidia.dali.fn as fn

# Inside your pipeline definition; mtx is your 2x3 affine matrix.
data = np.zeros([24, 48, 48], dtype=np.float32)
npy_0 = fn.external_source([data], batch=False)
npy_0 = fn.reshape(npy_0, layout="CHW") # <------------------------------ set a layout first
npy_0 = fn.transpose(npy_0, perm=[1,2,0], output_layout="HWC")
npy_0 = fn.warp_affine(npy_0, matrix=mtx)
npy_0 = fn.transpose(npy_0, perm=[2,0,1], output_layout="CHW")
We'll roll out a fix for this; until then, you can use the workaround shown in the example above.
Hi mzient,
Thanks for your help, I understand now.
I have checked my output again. It seems that when I save a NumPy array of shape (23,24,24) and load it with the numpy reader, I get an array of shape (1,23,24,24), i.e. a 4D array.
If so, should I use a 3D transformation matrix or squeeze the array in the pipeline? Which is easier?
As you said:
Option 1: Squeeze it in the pipeline.
npy_0 = self.np_reader().squeeze() # 4D arrays -> 3D; I tried this, but it doesn't work
npy_0 = fn.reshape(npy_0, layout="CHW")
npy_0 = fn.transpose(npy_0, perm=[1,2,0], output_layout="HWC")
npy_0 = self.wa_mask1(npy_0)
npy_0 = fn.transpose(npy_0, perm=[2,0,1], output_layout="CHW")
Option 2: Use a 3D matrix.
How should I convert the produced matrix to 3D?
npy_0 = self.np_reader() # 4D arrays
npy_0 = fn.reshape(npy_0, layout="DCHW")
npy_0 = fn.transpose(npy_0, perm=[0,2,3,1], output_layout="DHWC")
npy_0 = self.wa_mask1(npy_0) # 2D matrix to 3D
npy_0 = fn.transpose(npy_0, perm=[0,3,1,2], output_layout="DCHW")
Which is more convenient? Thanks!!
Also, I have another question:
Can I define a function that can be called in the pipeline?
For example:
npy_0 = self.np_reader() # 4D arrays
npy_0 = external_function(npy_0)
I tried it once, but it hurt performance very seriously (the external function may contain many for loops). Is there a way around this?
Thanks!
There's currently no "squeeze" operator, although we do plan to add one in the near future. For now, you can squeeze it like this:
npy_0 = self.np_reader()
# input XCHW where X is the redundant dimension
# desired layout HWCX
# permutation to achieve this is 2, 3, 1, 0
npy_0 = fn.transpose(npy_0, perm=[2,3,1,0])
npy_0 = fn.reshape(npy_0, rel_shape=[1,1,1], layout="HWC") # drop trailing dimension; data is now HWC
npy_0 = self.wa_mask1(npy_0)
npy_0 = fn.transpose(npy_0, perm=[2,0,1], output_layout="CHW")
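The shape bookkeeping of this transpose-then-reshape squeeze can be checked in NumPy. This sketch uses placeholder data and assumes, as above, that DALI's `perm` matches `numpy.transpose`. The last line confirms the result equals squeezing first and then transposing CHW to HWC.

```python
import numpy as np

# Placeholder reader output: a redundant leading axis X in front of CHW.
x_chw = np.random.rand(1, 23, 24, 24).astype(np.float32)

# Step 1: XCHW -> HWCX, moving the size-1 axis to the end.
hwcx = np.transpose(x_chw, (2, 3, 1, 0))
print(hwcx.shape)  # (24, 24, 23, 1)

# Step 2: drop the trailing size-1 axis (the role of rel_shape=[1,1,1]).
hwc = hwcx.reshape(hwcx.shape[:3])
print(hwc.shape)  # (24, 24, 23)

# Same data as squeezing first and then transposing CHW -> HWC.
print(np.array_equal(hwc, np.transpose(x_chw[0], (1, 2, 0))))  # True
```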
Regarding PythonFunction: yes, it's slow and mostly intended for prototyping. We're exploring the area, but there's no definitive solution yet. If your input data is not very large, you can load it using external source and do the custom processing there:
def load_sample(sample_info):
    idx = my_dataset_shuffling_function(sample_info.idx_in_epoch)
    data = my_magic_dataset_load(idx)
    return my_magic_processing(data)

custom_data = fn.external_source(load_sample, batch=False)
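Here is one way the placeholder functions might be filled in, as a sketch only. The in-memory `DATASET`, the seed, and the per-joint normalization are hypothetical stand-ins for the user's data and processing. The callback assumes the `sample_info` object passed by `external_source` exposes `idx_in_epoch` and `epoch_idx`, and that raising `StopIteration` ends the epoch; check both against your DALI version. The processing uses vectorized NumPy rather than Python for loops, which is usually what keeps such a callback fast.

```python
import numpy as np
from types import SimpleNamespace

# Hypothetical in-memory dataset: 10 heatmap stacks of shape (23, 24, 24).
DATASET = np.random.rand(10, 23, 24, 24).astype(np.float32)
SEED = 42

def load_sample(sample_info):
    """Per-sample callback for fn.external_source(load_sample, batch=False)."""
    if sample_info.idx_in_epoch >= len(DATASET):
        raise StopIteration  # end of epoch
    # Shuffle per epoch: every sample of one epoch sees the same permutation.
    perm = np.random.default_rng(SEED + sample_info.epoch_idx).permutation(len(DATASET))
    data = DATASET[perm[sample_info.idx_in_epoch]]
    # Hypothetical processing, vectorized instead of looping over joints:
    # scale each joint's heatmap so its peak value is 1.
    peak = data.max(axis=(1, 2), keepdims=True)
    return data / np.maximum(peak, 1e-6)

# Outside DALI, call it with a stand-in for the sample_info object.
out = load_sample(SimpleNamespace(idx_in_epoch=0, epoch_idx=0))
print(out.shape)  # (23, 24, 24)
```

Recomputing the permutation for every sample keeps the sketch stateless; a real loader would probably cache it once per epoch.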
Hi,
Thanks for your help.
I discovered that the output of ops.WarpAffine is not equal to that of cv2.warpAffine, given the same decoded image.
Specifically:
Should they be the same?
The affine matrix and the decoder outputs are the same, as expected.
When I print the output after affine:
Should they be the same? Or did I miss some parameters?
Thanks for your help!