
Implement a Java UDF and call it from PySpark

I need to create a UDF to be used in PySpark (Python) that uses a Java object for its internal calculations.

If it were simple Python, I would do something like:

def f(x): 
    return 7 
fudf = pyspark.sql.functions.udf(f,pyspark.sql.types.IntegerType()) 

and call it using:

df = sqlContext.range(0,5) 
df2 = df.withColumn("a", fudf(df.id))
df2.show()

However, the implementation of the function needs to be in Java, not Python. I need to wrap it somehow so that I can call it from Python in a similar way.

My first attempt was to implement the Java object, wrap it in Python inside PySpark, and convert that to a UDF. This failed with a serialization error.

Java code:

package com.test1.test2; 

public class TestClass1 { 
    Integer internalVal; 
    public TestClass1(Integer val1) { 
     internalVal = val1; 
    } 
    public Integer do_something(Integer val) { 
     return internalVal; 
    }  
} 

PySpark code:

from py4j.java_gateway import java_import 
from pyspark.sql.functions import udf 
from pyspark.sql.types import IntegerType 
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
audf = udf(a,IntegerType()) 

The error:

--------------------------------------------------------------------------- 
Py4JError         Traceback (most recent call last) 
<ipython-input-2-9756772ab14f> in <module>() 
     4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
     5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
----> 6 audf = udf(a,IntegerType()) 

/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType) 
    1595  [Row(slen=5), Row(slen=3)] 
    1596  """ 
-> 1597  return UserDefinedFunction(f, returnType) 
    1598 
    1599 blacklist = ['map', 'since', 'ignore_unicode_prefix'] 

/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name) 
    1556   self.returnType = returnType 
    1557   self._broadcast = None 
-> 1558   self._judf = self._create_judf(name) 
    1559 
    1560  def _create_judf(self, name): 

/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name) 
    1565   command = (func, None, ser, ser) 
    1566   sc = SparkContext.getOrCreate() 
-> 1567   pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) 
    1568   ctx = SQLContext.getOrCreate(sc) 
    1569   jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) 

/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj) 
    2297  # the serialized command will be compressed by broadcast 
    2298  ser = CloudPickleSerializer() 
-> 2299  pickled_command = ser.dumps(command) 
    2300  if len(pickled_command) > (1 << 20): # 1M 
    2301   # The broadcast will have same life cycle as created PythonRDD 

/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj) 
    426 
    427  def dumps(self, obj): 
--> 428   return cloudpickle.dumps(obj, 2) 
    429 
    430 

/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol) 
    644 
    645  cp = CloudPickler(file,protocol) 
--> 646  cp.dump(obj) 
    647 
    648  return file.getvalue() 

/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj) 
    105   self.inject_addons() 
    106   try: 
--> 107    return Pickler.dump(self, obj) 
    108   except RuntimeError as e: 
    109    if 'recursion' in e.args[0]: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj) 
    222   if self.proto >= 2: 
    223    self.write(PROTO + chr(self.proto)) 
--> 224   self.save(obj) 
    225   self.write(STOP) 
    226 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    566   write(MARK) 
    567   for element in obj: 
--> 568    save(element) 
    569 
    570   if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name) 
    191   if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None: 
    192    #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) 
--> 193    self.save_function_tuple(obj) 
    194    return 
    195   else: 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func) 
    234   # create a skeleton function object and memoize it 
    235   save(_make_skel_func) 
--> 236   save((code, closure, base_globals)) 
    237   write(pickle.REDUCE) 
    238   self.memoize(func) 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    552   if n <= 3 and proto >= 2: 
    553    for element in obj: 
--> 554     save(element) 
    555    # Subtle. Same as in the big comment below. 
    556    if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj) 
    604 
    605   self.memoize(obj) 
--> 606   self._batch_appends(iter(obj)) 
    607 
    608  dispatch[ListType] = save_list 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items) 
    637     write(MARK) 
    638     for x in tmp: 
--> 639      save(x) 
    640     write(APPENDS) 
    641    elif n: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    304    reduce = getattr(obj, "__reduce_ex__", None) 
    305    if reduce: 
--> 306     rv = reduce(self.proto) 
    307    else: 
    308     reduce = getattr(obj, "__reduce__", None) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    811   answer = self.gateway_client.send_command(command) 
    812   return_value = get_return_value(
--> 813    answer, self.gateway_client, self.target_id, self.name) 
    814 
    815   for temp_arg in temp_args: 

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw) 
    43  def deco(*a, **kw): 
    44   try: 
---> 45    return f(*a, **kw) 
    46   except py4j.protocol.Py4JJavaError as e: 
    47    s = e.java_exception.toString() 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name) 
    310     raise Py4JError(
    311      "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n". 
--> 312      format(target_id, ".", name, value)) 
    313   else: 
    314    raise Py4JError(

Py4JError: An error occurred while calling o18.__getnewargs__. Trace: 
py4j.Py4JException: Method __getnewargs__([]) does not exist 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335) 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344) 
    at py4j.Gateway.invoke(Gateway.java:252) 
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133) 
    at py4j.commands.CallCommand.execute(CallCommand.java:79) 
    at py4j.GatewayConnection.run(GatewayConnection.java:209) 
    at java.lang.Thread.run(Thread.java:745) 

EDIT: I also tried making the Java class Serializable, but to no avail.

My second attempt was to start from a Java UDF, but that failed as well, because I am not sure how to wrap it correctly:

Java code:

package com.test1.test2;

import org.apache.spark.sql.api.java.UDF1; 

public class TestClassUdf implements UDF1<Integer, Integer> { 

    Integer retval; 

    public TestClassUdf(Integer val) { 
     retval = val; 
    } 

    @Override 
    public Integer call(Integer arg0) throws Exception { 
     return retval; 
    } 
} 

But how would I use it? I tried:

from py4j.java_gateway import java_import 
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf") 
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
dfint = sqlContext.range(0,15) 
df = dfint.withColumn("a",a(dfint.id)) 

but I get:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-5-514811090b5f> in <module>() 
     3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
     4 dfint = sqlContext.range(0,15) 
----> 5 df = dfint.withColumn("a",a(dfint.id)) 

TypeError: 'JavaObject' object is not callable 

and I tried using a.call instead:

df = dfint.withColumn("a",a.call(dfint.id)) 

but I get:

---------------------------------------------------------------------------
TypeError         Traceback (most recent call last)
<ipython-input-...> in <module>()
     3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
     4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a.call(dfint.id))

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    796  def __call__(self, *args): 
    797   if self.converters is not None and len(self.converters) > 0: 
--> 798    (new_args, temp_args) = self._get_args(args) 
    799   else: 
    800    new_args = args 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args) 
    783     for converter in self.gateway_client.converters: 
    784      if converter.can_convert(arg): 
--> 785       temp_arg = converter.convert(arg, self.gateway_client) 
    786       temp_args.append(temp_arg) 
    787       new_args.append(temp_arg) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client) 
    510   HashMap = JavaClass("java.util.HashMap", gateway_client) 
    511   java_map = HashMap() 
--> 512   for key in object.keys(): 
    513    java_map[key] = object[key] 
    514   return java_map 

TypeError: 'Column' object is not callable 

Any help would be appreciated.

Answer


I got this working with the help of another question (and answer) of your own about UDAFs.

Spark provides a udf() method for wrapping Scala FunctionN, so we can wrap the Java function in Scala and use that. Your Java method needs to be static or on a class that implements Serializable.

package com.example 

import org.apache.spark.sql.UserDefinedFunction 
import org.apache.spark.sql.functions.udf 

class MyUdf extends Serializable { 
    def getUdf: UserDefinedFunction = udf(() => MyJavaClass.MyJavaMethod()) 
} 

Usage in PySpark:

def my_udf():
    from pyspark.sql.column import Column, _to_java_column, _to_seq
    pcls = "com.example.MyUdf"
    # load the Scala wrapper class in the JVM and grab the wrapped UserDefinedFunction
    jc = sc._jvm.java.lang.Thread.currentThread() \
     .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    # apply it to an (empty) sequence of columns and wrap the result as a Python Column
    return Column(jc(_to_seq(sc, [], _to_java_column)))

rdd1 = sc.parallelize([{'c1': 'a'}, {'c1': 'b'}, {'c1': 'c'}]) 
df1 = rdd1.toDF() 
df2 = df1.withColumn('mycol', my_udf()) 

As with the UDAF in your other question and answer, we can pass columns into it with return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column))).
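
To make that column-passing variant concrete, here is a minimal sketch. The names my_udf_on and df3 are illustrative, it assumes the Scala side is changed so that getUdf wraps a one-argument function (e.g. udf((s: String) => MyJavaClass.MyJavaMethod(s))), and it has not been verified against a particular Spark version:

def my_udf_on(*col_names):
    from pyspark.sql.column import Column, _to_java_column, _to_seq
    pcls = "com.example.MyUdf"
    # load the Scala wrapper in the JVM and get the wrapped UserDefinedFunction
    jc = sc._jvm.java.lang.Thread.currentThread() \
     .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    # convert the given column names to a Seq[Column] and apply the UDF to them
    return Column(jc(_to_seq(sc, list(col_names), _to_java_column)))

# hypothetical usage on the df1 from above
df3 = df1.withColumn('mycol2', my_udf_on('c1'))

Because the resulting expression lives entirely in the JVM, nothing Java-backed has to be pickled on the Python side, which is what caused the serialization error in the first attempt.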