Add axis to the batched add reduce #1131
jfix71 merged 1 commit into pytorch:master from jfix71:add_reduce_sum
Conversation
Thanks for doing this work Jordan. This PR/suggestion has two parts. First, extend the batched-add operator to support reduction over a non-zero dimension, i.e. the ability to select which dimension we perform the reduction on. And second, the specific implementation of the operator. I think that #1 makes sense: we need to support this kind of operation, and adding the axis/dimension argument to the operator makes sense. About #2, I don't have a strong opinion here. After all, this decision is constrained to the scope of one function (per backend), so even if we make a horrible mistake it will be easy to fix. Overall, it looks like a good direction to me.
I've added support for the CPU backend and for quantization. I skipped OpenCL for now. From looking at the generated LLVM IR it looks like vectorization is still occurring. For quantization, I needed to add different cases because we need the inner loop to do all accumulation in a local variable with more precision before clipping it back down. |
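To make the quantized case concrete, here is a minimal sketch of the idea of accumulating in a wider local variable and only clipping back down once per output element. This is not the actual interpreter/libjit code: it assumes a reduction over the first dimension only, and the scale/offset handling is deliberately simplified.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Hypothetical sketch: reduce-add `numSlices` int8 slices of `sliceSize`
// elements each. The inner loop accumulates in int32 (higher precision) and
// the result is rescaled and clipped to int8 only once per output element.
void reduceAddQuantized(int8_t *dest, const int8_t *batch, size_t numSlices,
                        size_t sliceSize, float batchScale, int32_t batchOffset,
                        float destScale, int32_t destOffset) {
  for (size_t i = 0; i < sliceSize; ++i) {
    int32_t sum = 0; // wider local accumulator
    for (size_t n = 0; n < numSlices; ++n)
      sum += batch[n * sliceSize + i] - batchOffset;
    // Rescale once, then clip into the destination's int8 range.
    int32_t result = static_cast<int32_t>(
                         std::round(sum * batchScale / destScale)) +
                     destOffset;
    int32_t clipped = std::max<int32_t>(-128, std::min<int32_t>(127, result));
    dest[i] = static_cast<int8_t>(clipped);
  }
}
```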
For seq2seq, we need to add an axis to batched reduce add. This means the current simple implementation is no longer sufficient, since we need to reduce over dimensions other than the first one. I am wondering if the implementation I have here in InterpreterNodes for `fwdBatchedReduceAddInst()` is a preferable direction compared to the other options.

The problem is that we don't know ahead of time the number of dimensions of the tensor we want to iterate over. To handle this in other cases, we generally do one of two things. One, write a separate case for each possible number of dimensions, i.e. a different loop nest depth per case (e.g. in `tryTransposeFastImpl()`, `libjit_transpose_generic()`, `libjit_insert_tensor()`). Or two, have a generic recursive version (e.g. in `transposeGenericImpl()`, `insertTensorsImpl()`), which might not get great performance and is much less readable/understandable. (I had implemented a third option for broadcast that iterated over generic shapes, but I removed it once we removed the `BroadcastInst`, and I don't think it had great perf anyway.)

Instead, the approach I took here was to get an unowned view of both the source batch Tensor and the dest Tensor with their dimensions expanded up to `max_tensor_dimensions`, with the newly added dimensions set to 1. This allows us to have a single loop nest of depth `max_tensor_dimensions`. This should enable good perf since it consists of relatively affine accesses/loops, doesn't require `n` different cases for the different numbers of dimensions, and is still pretty readable IMO. I was thinking we could possibly move toward this sort of implementation in libjit too -- I think it's more readable/maintainable, and post-specialization I would imagine it would have the same performance (?).

What do you all think?
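To make the idea concrete, here is a minimal sketch of that approach in plain C++. It is not the actual Glow code: `kMaxDims`, `expandDims`, and `flatIndex` are hypothetical helpers standing in for `max_tensor_dimensions` and the unowned-view machinery, and the max rank is set to 4 for brevity.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

// Pad both shapes to a fixed maximum rank so a single loop nest of that
// depth handles any input rank.
constexpr size_t kMaxDims = 4; // stand-in for max_tensor_dimensions
using Shape = std::array<size_t, kMaxDims>;

// Pad a shape with trailing 1s up to kMaxDims.
static Shape expandDims(const std::vector<size_t> &dims) {
  Shape e;
  e.fill(1);
  for (size_t i = 0; i < dims.size(); ++i)
    e[i] = dims[i];
  return e;
}

// Row-major flat offset of an index into a shape.
static size_t flatIndex(const Shape &idx, const Shape &dims) {
  size_t off = 0;
  for (size_t i = 0; i < kMaxDims; ++i)
    off = off * dims[i] + idx[i];
  return off;
}

// Reduce-add `batch` along `axis` into `dest`. The dest is viewed as the
// batch shape with dims[axis] == 1, so source and dest indices line up.
static void batchedReduceAdd(float *dest, const float *batch,
                             const std::vector<size_t> &batchDims,
                             size_t axis) {
  Shape eBatch = expandDims(batchDims);
  Shape eDest = eBatch;
  eDest[axis] = 1;

  // Zero the destination.
  size_t destSize = 1;
  for (size_t d : eDest)
    destSize *= d;
  for (size_t i = 0; i < destSize; ++i)
    dest[i] = 0.f;

  // Single loop nest of depth kMaxDims; unused dimensions have extent 1,
  // so the accesses stay simple and affine regardless of the input rank.
  for (size_t i0 = 0; i0 < eBatch[0]; ++i0)
    for (size_t i1 = 0; i1 < eBatch[1]; ++i1)
      for (size_t i2 = 0; i2 < eBatch[2]; ++i2)
        for (size_t i3 = 0; i3 < eBatch[3]; ++i3) {
          Shape srcIdx = {i0, i1, i2, i3};
          Shape dstIdx = srcIdx;
          dstIdx[axis] = 0; // collapse the reduced axis
          dest[flatIndex(dstIdx, eDest)] += batch[flatIndex(srcIdx, eBatch)];
        }
}

int main() {
  // Reduce a {2, 3} batch along axis 1: {{1,2,3},{4,5,6}} -> {6, 15}.
  const float batch[] = {1, 2, 3, 4, 5, 6};
  float dest[2];
  batchedReduceAdd(dest, batch, {2, 3}, /*axis=*/1);
  printf("%g %g\n", dest[0], dest[1]);
  return 0;
}
```

The same loop nest handles a rank-1, rank-2, or rank-4 batch without any per-rank cases, which is the maintainability argument above; whether a backend's vectorizer does as well on it as on hand-specialized loop depths would still need to be measured.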