22
33import torch
44from torch import Tensor
5- from torchao .utils import TorchAOBaseTensor
5+ from torch .utils ._python_dispatch import return_and_correct_aliasing
6+ from torchao .utils import TorchAOBaseTensor , TORCH_VERSION_AT_LEAST_2_4
67
78from .quant_utils import create_dynamic_map , scale_tensor , quantize_4bit_with_qmap , dequant_with_qmap
89
@@ -60,8 +61,9 @@ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=No
6061 def dequantize (self , output_dtype = None ):
6162 codes = torch .stack ([self .codes >> 4 , self .codes & 0b1111 ], dim = - 1 ) # unpack
6263 float_data = dequant_with_qmap (codes , self .qmap , self .scale )
63- dtype = output_dtype or torch .get_default_dtype ()
64- return float_data .view (self ._shape ).to (dtype )
64+ if output_dtype is not None :
65+ float_data = float_data .to (output_dtype )
66+ return float_data .view (self ._shape )
6567
6668 @classmethod
6769 def zeros (cls , shape , signed : bool = True , block_size : int = 128 , device = None ):
@@ -80,6 +82,24 @@ def __repr__(self):
8082 )
8183
8284
85+ # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when
86+ # dtype is the same but device is different. thus, we must override .to() method instead.
87+ if not TORCH_VERSION_AT_LEAST_2_4 :
88+ def _to (self , * args , ** kwargs ):
89+ # ignore other args/kwargs
90+ device = kwargs .pop ("device" , None )
91+ return OptimState4bit (
92+ self .codes .to (device ),
93+ self .scale .to (device ),
94+ self .qmap .to (device ),
95+ self .signed ,
96+ self .shape ,
97+ )
98+
99+ OptimState4bit .to = _to
100+ del _to # make sure to not re-use
101+
102+
83103@OptimState4bit .implements (aten .copy_ .default )
84104def _ (func , types , args , kwargs ):
85105 dst = args [0 ]
@@ -107,6 +127,20 @@ def _(func, types, args, kwargs):
107127 return dst
108128
109129
130+ @OptimState4bit .implements (aten ._to_copy .default )
131+ def _ (func , types , args , kwargs ):
132+ # ignore dtype
133+ device = kwargs .get ("device" , None )
134+ out = OptimState4bit (
135+ args [0 ].codes .to (device = device ),
136+ args [0 ].scale .to (device = device ),
137+ args [0 ].qmap .to (device = device ),
138+ args [0 ].signed ,
139+ args [0 ].shape ,
140+ )
141+ return return_and_correct_aliasing (func , args , kwargs , out )
142+
143+
110144@OptimState4bit .implements (aten .lerp .Scalar )
111145def _ (func , types , args , kwargs ):
112146 args = [x .dequantize () if isinstance (x , OptimState4bit ) else x for x in args ]
0 commit comments